Code Notes 13: Implementing Cross-Entropy for Semantic Segmentation (Excluding the Background Class)
Notes
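This is a small record for myself. I used to rarely read open-source code and just wrote everything on my own with my head down, only to discover in the end that what I wrote was garbage. You have to read, learn, and practice. Reading open-source code doesn't mean you never build your own wheels; it means learning how other people think about programming. Consider this my self-criticism.
I've been under a lot of pressure lately and want to get this paper out, but nobody is guiding me. My advisor is a master at bringing in projects and only hands out odd jobs; teaching you is out of the question, because he doesn't know this material himself. The senior students have become very laid-back in this environment; there are no graduation requirements anyway, so nobody thinks about writing papers or doing research in this area. I found another teacher who works in this direction, but I can only ask a few questions privately; he isn't my own advisor, so it's already generous of him to reply and help at all. Over these months I've learned a lot and also fallen into many pits and taken many detours, and I doubt anyone really understands that feeling. Still, I'm grateful to my mom and my girlfriend, who always support me emotionally, even though I haven't got this sorted out yet, hahaha. Once things are more or less done I'll summarize what I've learned and post it bit by bit.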
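2D cross-entropy
I won't say much about this; you can look it up online. Cross-entropy is an information-theoretic measure of how much a predicted distribution differs from the target distribution. The main issue is that in indoor semantic segmentation we often don't want a background class. For example, SUN-RGBD and NYU both contain a label 0 (ignored), and this label should not be counted as a class in the loss function.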
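To make the masking idea concrete, here is a minimal pure-Python sketch (no PyTorch) of per-pixel cross-entropy that skips label 0. The function name `masked_cross_entropy` and the toy scores are hypothetical and only for illustration; it assumes real classes are labelled 1..C and that the scores hold one value per real class.

```python
import math

def masked_cross_entropy(logits, labels, ignore_label=0):
    """Mean cross-entropy over pixels whose label is not `ignore_label`.

    logits: list of per-pixel score lists, one score per real class.
    labels: list of integer labels (1..C for real classes, 0 = ignored).
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_label:
            continue  # ignored pixels contribute nothing to the loss
        # log-sum-exp with the max subtracted for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        # label k (1-based) corresponds to score index k - 1
        total += log_z - scores[label - 1]
        count += 1
    return total / count if count else 0.0

# Three pixels, two real classes; the middle pixel is ignored (label 0).
logits = [[2.0, 0.0], [5.0, -5.0], [0.0, 0.0]]
labels = [1, 0, 2]
loss = masked_cross_entropy(logits, labels)
```

Whatever scores the ignored pixel has, the result is the same, which is exactly the behaviour we want from the loss.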
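PyTorch's built-in cross-entropy loss [1]:
CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
It has an ignore_index parameter. Used directly on the raw labels it isn't quite what we need, because ignoring label 0 would still leave class 0 occupying an output channel, but combined with a small label shift it gives the same result, as I show later.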
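As a quick preview of how `ignore_index` behaves, here is a small sketch with made-up data. It assumes labels 1..3 are real classes and 0 is ignored, shifts them by one, and checks two things: the default mean only covers the valid pixels, and the ignored pixel receives no gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
score = torch.randn(1, 3, 2, 2, requires_grad=True)  # (B, C, H, W), 3 real classes
label = torch.tensor([[[1, 0], [3, 2]]])             # 0 = ignored, 1..3 = real classes

# Shift labels so real classes become 0..2 and the ignored label becomes -1.
target = label - 1

criterion = nn.CrossEntropyLoss(ignore_index=-1)  # default reduction='mean'
loss = criterion(score, target)
loss.backward()

# Gradient at the ignored pixel (row 0, column 1), for all three channels.
grad_at_ignored = score.grad[0, :, 0, 1]

# Reference: average the per-pixel losses of the valid pixels by hand.
valid = label > 0
per_pixel = F.cross_entropy(
    score.detach().permute(0, 2, 3, 1)[valid],  # (N_valid, C)
    target[valid],                              # (N_valid,)
    reduction="none",
)
reference = per_pixel.mean()
```

Without class weights, `loss` and `reference` agree, so ignored pixels are left out of both the sum and the denominator.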
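Implementation
I show three methods here, with the complete code and test results below.
The first is the one I wrote, and it's pretty bad: flatten the 2D feature maps into 1D, filter out class 0, and then compute cross-entropy.
The second and third both work the other way round: compute the 2D cross-entropy first to get a loss value for every pixel, then filter by label.
The second method comes from GitHub [2].
To think I spent almost a whole day writing mine; it's so much worse by comparison...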
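Before the full listing, the flatten-then-filter idea behind the first method can be sketched without a per-image loop. This is not the code I originally wrote (that is in the listing that follows); the function name `flatten_filter_ce` is hypothetical, and it assumes `score` has shape (B, C, H, W) and `label` has shape (B, H, W) with 0 as background.

```python
import torch
import torch.nn.functional as F

def flatten_filter_ce(score, label, weight=None):
    """Per-pixel cross-entropy over non-background pixels only.

    score: (B, C, H, W) logits for the C real classes.
    label: (B, H, W) integer labels, 0 = background, 1..C = real classes.
    Returns a 1-D tensor with one loss value per non-background pixel.
    """
    num_classes = score.shape[1]
    # Move the class axis last, then flatten all pixels: (B*H*W, C).
    flat_score = score.permute(0, 2, 3, 1).reshape(-1, num_classes)
    flat_label = label.reshape(-1)
    keep = flat_label > 0  # drop background pixels before the loss
    return F.cross_entropy(flat_score[keep], flat_label[keep] - 1,
                           weight=weight, reduction="none")

torch.manual_seed(0)
weight = torch.tensor([0.9627, 6.31385, 2.3358])
score = torch.randn(1, 3, 4, 4)
label = torch.tensor([[[1, 0, 2, 3],
                       [0, 3, 3, 2],
                       [1, 1, 0, 2],
                       [3, 2, 0, 2]]])
losses = flatten_filter_ce(score, label, weight)
```

The pixels come out in row-major order, the same order `torch.masked_select` would give on a (B, H, W) loss map, so the output can be compared directly with the other methods.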
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossEntropyLoss1(nn.Module):
    """Method 1: drop the background pixels first, then compute cross-entropy."""

    def __init__(self, classesnum, device):
        super().__init__()
        # Per-class weights for the loss, derived from the proportion of pixels in each class.
        # self.weight = torch.FloatTensor([0.9627, 6.31385, 2.3358, 1.0000, 0.34482, 0.23155, 0.26256, 1.69797, 1.43697, 1.12408, 6.11624, 0.14348, 0.71114])
        self.weight = torch.FloatTensor([0.9627, 6.31385, 2.3358])
        self.weight = self.weight.to(device=device)
        # reduction='none' replaces the deprecated size_average=False, reduce=False
        self.crossentropyloss = nn.CrossEntropyLoss(weight=self.weight, reduction='none')
        # self.crossentropyloss = nn.CrossEntropyLoss()
        self.classesnum = classesnum

    def forward(self, score, label):
        B, C, H, W = score.size()
        # Shift labels so real classes are 0..C-1 and the background (0 in SUN-RGBD) becomes -1.
        label = label - 1
        label_mask = (label != -1)  # True for real pixels, False for background
        # Expand the mask to the shape of the score maps so real pixels can be extracted.
        score_mask = label_mask.expand(C, B, H, W).permute(1, 0, 2, 3)
        labels, scores = [], []
        for i in range(int(B)):
            # Keep only the non-background pixels of image i.
            labels.append(label[i][label_mask[i]])
            scores.append(score[i][score_mask[i]].reshape(self.classesnum, -1))
        lab_ej = torch.cat(labels, dim=0).unsqueeze(dim=0)    # (1, N)
        score_ej = torch.cat(scores, dim=1).unsqueeze(dim=0)  # (1, C, N)
        t, pixelsum = lab_ej.size()
        loss = self.crossentropyloss(score_ej, lab_ej)  # per-pixel loss, shape (1, N)
        prob = F.softmax(score_ej, dim=1)
        arg_max = torch.argmax(prob, dim=1)
        accurate = torch.sum(arg_max == lab_ej).item()  # number of correctly classified pixels
        return loss, pixelsum, accurate
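class CrossEntropyLoss2(nn.Module):
    """Method 2 (from RedNet [2]): per-pixel cross-entropy first, then mask out the background."""

    def __init__(self):
        super(CrossEntropyLoss2, self).__init__()
        self.weight = torch.FloatTensor([0.9627, 6.31385, 2.3358])
        # reduction='none' replaces the deprecated size_average=False, reduce=False
        self.ce_loss = nn.CrossEntropyLoss(weight=self.weight, reduction='none')

    def forward(self, inputs, targets):
        mask = targets > 0  # True for real classes, False for background (0)
        targets_m = targets.clone()
        targets_m[mask] -= 1  # shift real classes 1..C to 0..C-1; background stays 0
        loss_all = self.ce_loss(inputs, targets_m.long())  # per-pixel loss, shape (B, H, W)
        loss_all = torch.masked_select(loss_all, mask)  # keep only the non-background pixels
        return loss_all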
if __name__ == '__main__':
    loss1 = CrossEntropyLoss1(classesnum=3, device='cpu')
    loss2 = CrossEntropyLoss2()
    # Method 3: the built-in loss with ignore_index, applied to labels shifted by -1.
    loss3 = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.9627, 6.31385, 2.3358]), ignore_index=-1, reduction='none')
    label = torch.tensor([[[1, 0, 2, 3],
                           [0, 3, 3, 2],
                           [1, 1, 0, 2],
                           [3, 2, 0, 2]]])
    score = torch.randn([1, 3, 4, 4])
    crent1 = loss1(score, label)
    crent2 = loss2(score, label)
    crent3 = loss3(score, label - 1)
    print(crent1, '\n', crent2, '\n', crent3)
The results:
(tensor([[ 2.0845, 6.6480, 1.6856, 4.9340, 2.6181, 7.4988, 1.6213, 3.3337,
3.8872, 4.4253, 16.8919, 10.6067]]), 12, 2)
tensor([ 2.0845, 6.6480, 1.6856, 4.9340, 2.6181, 7.4988, 1.6213, 3.3337,
3.8872, 4.4253, 16.8919, 10.6067])
tensor([[[ 2.0845, 0.0000, 6.6480, 1.6856],
[ 0.0000, 4.9340, 2.6181, 7.4988],
[ 1.6213, 3.3337, 0.0000, 3.8872],
[ 4.4253, 16.8919, 0.0000, 10.6067]]])
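As you can see, the three results are identical. Note that ignore_index does not have to be negative; it is -1 here only because label - 1 shifts the real classes to 0..2 and turns the ignored label 0 into -1. Also, with the per-pixel output the ignored pixels come out as 0 in the loss map rather than being removed, so a later step still has to count the non-ignored pixels (from the label mask, which is safer than counting non-zero loss values) and divide by that count to get the mean.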
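The averaging step can be sketched as follows, reusing the label map from the test above with random scores (so the numbers will differ from the printed results). It also shows, as far as I understand the PyTorch documentation, that the built-in reduction='mean' with class weights divides by the sum of the weights of the non-ignored targets rather than by the pixel count, so the two averages are generally not the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.tensor([0.9627, 6.31385, 2.3358])
score = torch.randn(1, 3, 4, 4)
label = torch.tensor([[[1, 0, 2, 3],
                       [0, 3, 3, 2],
                       [1, 1, 0, 2],
                       [3, 2, 0, 2]]])
target = label - 1     # real classes 0..2, ignored label -1
valid = target != -1   # mask of non-ignored pixels

# Per-pixel weighted loss; ignored pixels are returned as 0.
per_pixel = nn.CrossEntropyLoss(weight=weight, ignore_index=-1,
                                reduction="none")(score, target)

# Option 1: plain average over the number of valid pixels.
mean_by_count = per_pixel.sum() / valid.sum()

# Option 2: normalise by the summed class weights of the valid targets,
# which is what the built-in reduction='mean' computes.
mean_by_weight = per_pixel.sum() / weight[target[valid]].sum()
builtin_mean = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)(score, target)
```

Which normalisation to use is a design choice; the important part is that the denominator comes from the label mask and not from the full H x W pixel count.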
References
[1] https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
[2] https://github.com/JindongJiang/RedNet