pytorch交叉熵损失函数是什么,怎样用
Admin 2022-07-26 群英技术资讯 625 次浏览
必须将权重也转为Tensor的cuda格式;
将该class_weight作为交叉熵函数对应参数的输入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
补充:关于pytorch的CrossEntropyLoss的weight参数
你可以试试下面代码
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,1,1]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.4803)
这里的手动计算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.6075)
手算发现,并不是单纯的那权重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
发现了么,加权后,除以的是权重的和,不是数目的和。
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5]) outputs = torch.LongTensor([0,1,2,2]) inputs = inputs.view((1,3,4)) outputs = outputs.view((1,4)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(weight=weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
#!/use/bin/envpython#-*-conding:utf-8-*-#author:shanshan"""写代码1,实现用户输入用户名和密码,当用户名为seven且密码为123时,显示登陆成功,否则登陆失败!2,实现用户输入用户名和密码,当用户名为seven且密码为123时,显示登陆成功,否则登陆失败,失败时允许重复输入三次3,实现用户输入用户名和密码,当用户
Matplotlib绘制图像显示中文的时候,中文会变成小方格子,下面这篇文章主要给大家介绍了关于如何彻底解决Python中matplotlib不显示中文问题的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
1、强制等待(sleep)fromtimeimportsleepsleep(3)#强制等待3秒缺点:由于Web加载的速度取决于测试的硬件、网速、服务器的响应时间等因素。如果等待时间太长,容
本篇文章给大家带来了关于python的相关知识,torch.Tensor 是一种包含单一数据类型元素的多维矩阵,类似于 numpy 的 array,下面一起来看一下Pytorch中的tensor数据结构,希望对大家有帮助。
这篇文章主要介绍了Python字符串中如何去除数字之间的逗号,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008