Focal Loss理解

 

------------------------------------------------------2023年12月11日15:25:59-------------------------------------------

focal loss可降低大量简单样本的在训练过程中所占权重,这句话的意思是,训练样本中一些相似的训练数据会在训练过程中被舍弃,因为之前的相似样本已经参与训练了。但是,这也不是意味着可以直接不经挑选的扩充训练样本,举个例子,有个分类模型区分人和狗,在人的类别中,有白人黄人和黑人,最好的情况是三种人各占三分之一,如果你一味的只扩充白种人,那即使你使用了focal loss,但是训练数据中黑种人的特征太少,分类效果永远也上不去,所以说在扩充样本时,一定要是更有质量的数据。低质量的数据,即对模型效果没有作用,又会增加训练时间。

-----------------------------------------------分割线---------------------------------------------------------------

1. 总述

Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。

 

2. 损失函数形式

Focal loss是在交叉熵损失函数基础上进行的修改,首先回顾二分类交叉上损失:

是经过激活函数的输出,所以在0-1之间。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。那么Focal loss是怎么改进的呢?

 

 

 

首先在原有的基础上加了一个因子,其中gamma>0使得减少易分类样本的损失。使得更关注于困难的、错分的样本。

例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效。

此外,加入平衡因子alpha,用来平衡正负样本本身的比例不均:文中alpha取0.25,即正样本要比负样本占比小,这是因为负例易分。

 

只添加alpha虽然可以平衡正负样本的重要性,但是无法解决简单与困难样本的问题。

gamma调节简单样本权重降低的速率,当gamma为0时即为交叉熵损失函数,当gamma增加时,调整因子的影响也在增加。实验发现gamma为2是最优。

 

3. 总结

作者认为one-stage和two-stage的表现差异主要原因是大量前景背景类别不平衡导致。作者设计了一个简单密集型网络RetinaNet来训练在保证速度的同时达到了精度最优。在双阶段算法中,在候选框阶段,通过得分和nms筛选过滤掉了大量的负样本,然后在分类回归阶段又固定了正负样本比例,或者通过OHEM在线困难挖掘使得前景和背景相对平衡。而one-stage阶段需要产生约100k的候选位置,虽然有类似的采样,但是训练仍然被大量负样本所主导。

代码实现:

 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.transforms as F

from IPython.display import display
class FocalLoss(nn.Module):

    def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

 

 

 

 

posted @ 2022-09-20 09:16  海_纳百川  阅读(66)  评论(0编辑  收藏  举报
本站总访问量