PyTorch:关于BCE、CE Loss的Mask分割二分类问题
形式1:输出为单通道
分析
即网络的输出 output
为 [batch_size, 1, height, width] 形状。其中 batch_szie
为批量大小,1
表示输出一个通道,height
和 width
与输入图像的高和宽保持一致。
在训练时,输出通道数是 1,网络得到的 output
包含的数值是任意的数。给定的 target
,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output
不断逼近这个标签,首先会让 output
经过一个sigmoid 函数,使其数值归一化到[0, 1],得到 output1
,然后让这个 output1
与 target
进行交叉熵计算,得到损失值,反向传播更新网络权重。最终,网络经过学习,会使得 output1
逼近target
。
训练结束后,网络已经具备让输出的 output
经过转换从而逼近 target
的能力。首先将输出的 output
通过sigmoid 函数,然后取一个阈值(一般设置为0.5),大于阈值则取1反之则取0,从而得到预测图 predict
。后续则是一些评估相关的计算。
代码实现
在这个过程中,训练的损失函数为二进制交叉熵损失函数,然后根据输出是否用到了sigmoid有两种可选的pytorch实现方式:
output = net(input) # net的最后一层没有使用sigmoid
loss_func1 = torch.nn.BCEWithLogitsLoss()
loss = loss_func1(output, target)
当网络最后一层没有使用sigmoid时,需要使用 torch.nn.BCEWithLogitsLoss()
,顾名思义,在这个函数中,拿到output首先会做一个sigmoid操作,再进行二进制交叉熵计算。上面的操作等价于
output = net(input) # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(output, target)
当然,你也可以在网络最后一层加上sigmoid操作。从而省去第二行的代码(在预测时也可以省去)。
在预测试时,可用下面的代码实现预测图的生成
output = net(input) # net的最后一层没有使用sigmoid
output = F.sigmoid(output)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...
即大于0.5的记为1,小于0.5记为0。
形式2:输出为多通道
分析
即网络的输出 output
为 [batch_size, num_class, height, width] 形状。其中 batch_szie
为批量大小,num_class
表示输出的通道数与分类数量一致,height
和 width
与输入图像的高和宽保持一致。
在训练时,输出通道数是 num_class
(这里取2),网络得到的 output
包含的数值是任意的数。给定的 target
,是一个单通道标签图,数值只有 0 和 1 这两种。为了让网络输出 output
不断逼近这个标签,首先会让 output
经过一个 softmax 函数,使其数值归一化到[0, 1],得到 output1
,在各通道中,这个数值加起来会等于1。对于target
他是一个单通道图,首先使用onehot
编码,转换成 num_class
个通道的图像,每个通道中的取值是根据单通道中的取值计算出来的,例如单通道中的第一个像素取值为1(0<= 1 <=num_class-1,这里num_class=2),那么onehot
编码后,在第一个像素的位置上,两个通道的取值分别为0,1。也就是说像素的取值决定了对应序号的通道取1,其他的通道取0,这个非常关键。上面的操作执行完后得到target1
,让这个 output1
与 target1
进行交叉熵计算,得到损失值,反向传播更新网路权重。最终,网络经过学习,会使得 output1
逼近target1
(在各通道层面上)。
训练结束后,网络已经具备让输出的 output
经过转换从而逼近 target
的能力。计算 output
中各通道每一个像素位置上,取值最大的那个对应的通道序号,从而得到预测图 predict
。后续则是一些评估相关的计算。
代码实现
在这个过程中,则可以使用交叉熵损失函数:
output = net(input) # net的最后一层没有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(output, target)
根据前面的分析,我们知道,正常的output
是 [batch_size, num_class, height, width]形状的,而target
是[batch_size, height, width]形状的,需要按照上面的分析进行转换才可以计算交叉熵,而在pytorch中,我们不需要进一步做这个处理,直接使用就可以了。
在预测试时,使用下面的代码实现预测图的生成
output = net(input) # net的最后一层没有使用sigmoid
predict = output.argmax(dim=1)
...
即得到输出后,在通道方向上找出最大值所在的索引号。
小结
总的来说,我觉得第二种方式更值得推广,一方面不用考虑阈值的选取问题;另一方面,该方法同样适用于多类别的语义分割任务,通用性更强。
reference: