Fork me on GitHub

图像语义分割训练经验总结--图像语义分割

  最近一直在学pytorch,copy了几个经典的入门问题。现在作一下总结。

  首先,做的小项目主要有

             分类问题:Mnist手写体识别、FashionMnist识别、猫狗大战

             语义分割:Unet分割肝脏图像、遥感图像

  先把语义分割的心得总结一下,目前只是一部分,以后还会随着学习的深入慢慢往里面加新的感悟。

  1)对于二分类问题

     1. Unet输出channel:对于二分类问题,类别数为2,channel为1,用uint8的单通道灰度图像表示类别就行(0/1)。

     2. label是单通道灰度图像,直接传给损失函数。

     3. 损失函数:nn.sigmoid + nn.BCELoss / nn.BCEWithLogitsLoss,此时计算loss的ouput和label维度应该保持一致。batchsize*1*h*w

 

  2)对于多分类问题

     1. Unet输出channel: 输出channel是类别数。网络的输入是img,网络的输出是one hot编码的多通道图像。

     2. Label是单通道灰度图像,不同的灰度级表示不同的类别。用于传给损失函数,计算Loss。

      具体操作方面,第一步有人说先将Label进行one hot编码(即转换成多通道图,一个通道一个类别),这样才能用交叉熵计算损失;也有人说不需要one hot编码,直接把单通道Label作为损失函数的Label。

      其实这两个人说的都不错,但第一个人并没有用Pytorch做,而第二个人是用Pytorch和nn.CrossEntropyLoss计算损失的。

      在多分类问题中,当损失函数为nn.CrossEntropyLoss()时,它会自动把标签转换成one hot形式。所以,我们在运用交叉熵损失函数时不必将标签也转换成onehot形式。在用到这种损失函数时,直接把单通道Label作为损失函数的Label即可,而网络输入的img得到的输出是one hot编码格式。最后为了可视化输出,用argmax取到索引,把多通道图片转换成单通道图片(不同灰度级表示不同类别),再用索引对应的RGB颜色表解码(伪彩色映射)得到分割图。

      ps. 总结一下。因为单通道的Label只是用来计算Loss的,而输入图片(img)到网络的输出又是多通道图片(One hot),所以为了计算损失函数,Label传递给损失函数前是肯定要one hot一下的,只是用nn.CrossEntropyLoss时,Label自动one hot了,所以不需要你手动去转换了。此外CrossEntropyLoss还内置了softmax函数,而BCELOSS却没有内置sigmoid函数,所以在网络输出层中,如果用前者不需要加softmax层,而后者需要加sigmoid层。

     3. 损失函数:nn.CrossEntorpyLoss计算。此时计算loss的output维度为batchsize*categories*h*w,label为batchsize*h*w。此外这个损失函数内置了softmax运算。

     4. 此外,这种多分类的方法有时候精度相对不高,可以转化成多个二分类问题,最后合成在一起。

 

  3)test时有时会取torch.argmax/torch.max来得到pred_label的索引,用于计算accuracy。这点图像分类方面用的比较多。

     train的时候一般不需要这个,直接输入模型的输出和Label计算Loss,再反向传递就可以。

           pred_y = torch.max(test_output, 1)[1].data.numpy()      #返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
                accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))    #准确数/batch_size,计算准确率

 

  4)predict出图部分。(讲一点自己的看法,可能不太对)

     对于二分类问题,经过sigmoid输出float类型的概率可以直接可视化,这种情况下mask的精确度不高但是很方便,这取决于你的需求;也可以设定阈值二分类0/1再映射到255。

     对于多分类问题,经过softmax输出概率后有三种方法选择,1是设置阈值,大于阈值为1,小于阈值为0,得到的是多通道图像(感觉这样阈值影响结果很大);2是对model的输出按channel取argmax,得到的应该是单通道的图像,索引对应channnel,这种情况下不需要用softmax取概率,只需要对原始输出取argmax就可以;3是直接可视化softmax输出的概率。

 

posted @ 2020-02-15 18:22  Rser_ljw  阅读(4184)  评论(3编辑  收藏  举报