人脸识别softmax损失函数

  设人脸特征向量为\(x\in R^{n}\),全连接层分类权重\(W\in R^{m\times n}\),此处假设特征向量和分类权重都已经归一化,n维向量的2范数都等于1.则经过全连接层后得到\(y=W^{T}x=(w_j^{T}x)_{j=1,2,...,m}\),其中\(w_j\)是分类类别j对应的连接权重。变换如下:
       \(w_j^{T}x = |w_j^{T}||x|cos\theta_{j}=cos\theta_{j}\)
其中\(\theta_{j}\)代表单位特征向量\(x\)与分类层的第j个连接权重(也理解为第j个类中心向量)之前的夹角,夹角越小则表明越靠近这个类中心,亦是表明属于这个类的可能性越大。经过softmax之后,得到x属于各个类的概率值
     \(softmax(y)_k =p_k= \frac{e^{cos\theta_k}}{\sum {e^{cos\theta_j}}}\)

接下来将采用交叉熵损失函数,就是常见的softmax-loss 分类损失。

为什么要对特征向量x和连接权重w进行归一化呢?

一方面测试阶段是根据样本特征间的余弦相似度来衡量,样本特征范数大小是不起作用的;另一方面,训练过程中所属类别概率最大化,也可能是x范数值大带来的,这个对于测试是没有意义的。训练表现好,而测试却不好。这是由于训练和测试没有保持一致。因此设定特征x范数固定为s即\(|x|=s\)。同样分类权重W的范数大小对于测试结果也是没有意思,之后干扰训练,因此也将其固定。

  截至目前,训练优化目标是在角度空间,使得样本特征与其对于类中心夹角最小,而这与测试过程使用余弦相似保持一致,消除了连接权重和特征向量范数的干扰。但是还有提升空间。回想一下支持向量机SVM,为了使得SVM分类效果更好,引入了margin概念,避免原始S超平面附近样本无法有效分类的问题。在角度空间也同样存在类似问题。参考下图,在训练过程中没有考虑决策边界样本无法有效区分的问题,就会出现图(a)状况,存在样本不靠近其所属类中心,汇聚在类类边界上,导致测试过程无法区分。引入margin后,这种情况就出现了很大的改善。如图(b)

image

如何引入margin呢?
我就不从逐一介绍了,主要介绍2种,一个是cosface,一个就是arcface。
(1)\(p_k= \frac{e^{s·cos\theta_k}}{\sum {e^{s·cos\theta_j}}}\) --> \(p_k= \frac{e^{s·(cos\theta_k-m)}}{\sum {e^{s·cos\theta_j} \ \ \ + \ e^{s·(cos\theta_k-m)}}}\)

此举目的,是为了在余弦空间,让加入margin后仍能正确分类,从而实现更加鲁邦的特征表示学习

(2)\(p_k= \frac{e^{s·cos\theta_k}}{\sum {e^{s·cos\theta_j}}}\) --> \(p_k= \frac{e^{s·cos(\theta_k+m)}}{\sum {e^{s·cos\theta_j}\ \ \ + \ e^{s·cos(\theta_k+m)}}}\)
此举目的,是为了在角度空间,让加入margin后仍能正确分类,从而实现更加鲁邦的特征表示学习。

(3)Arc一文对margin进行了总结,围绕角度\(\theta\)进行margin的扩展,引入3个参数\(m_1, m_2, m_3\)将原始\(cos\theta\)修改为\(cos(m_1\theta+m_2)-m_3\).于是对于样本x的真实类别k对应的概率值就修正为了\(p_k= \frac{e^{s·cos\theta_k}}{\sum {e^{s·cos\theta_j}}}\) --> \(p_k= \frac{e^{s·(cos(m_1\theta_k+m_2)+m_3)}}{\sum {e^{s·cos\theta_j}\ \ \ + \ e^{s·(cos(m_1\theta_k+m_2)+m_3)}}}\)

TODO List:
(1)相关公式推导

下面是insightface代码中有关softmax-margin的实现。代码中的loss_m1, loss_m2, loss_m3,就是对应上述公式中的m1,m2,m3.
https://github.com/deepinsight/insightface/tree/master/recognition/ArcFace/train.py
(1)这一块代码就是计算\(scos\theta\)

        s = config.loss_s
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s

        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=config.num_classes,
                                    name='fc7')

(2)这一块计算的是\(s(cos\theta - m_3)\),只是对样本真实类别位置处进行这个减去m3的操作

            if config.loss_m1 == 1.0 and config.loss_m2 == 0.0:
                s_m = s * config.loss_m3
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=config.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot

(3)剩余的else部分完全就是按公式实现\(s·(cos(m_1\theta+m_2)-m_3)\)

完整的前向计算函数如下:

def get_symbol(args):
    embedding = eval(config.net_name).get_symbol()
    all_label = mx.symbol.Variable('softmax_label')
    gt_label = all_label
    is_softmax = True
    if config.loss_name == 'softmax':  #softmax
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(config.num_classes,
                                            config.emb_size),
                                     lr_mult=config.fc7_lr_mult,
                                     wd_mult=config.fc7_wd_mult,
                                     init=mx.init.Normal(0.01))
        if config.fc7_no_bias:
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        no_bias=True,
                                        num_hidden=config.num_classes,
                                        name='fc7')
        else:
            _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
            fc7 = mx.sym.FullyConnected(data=embedding,
                                        weight=_weight,
                                        bias=_bias,
                                        num_hidden=config.num_classes,
                                        name='fc7')
    elif config.loss_name == 'margin_softmax':
        _weight = mx.symbol.Variable("fc7_weight",
                                     shape=(config.num_classes,
                                            config.emb_size),
                                     lr_mult=config.fc7_lr_mult,
                                     wd_mult=config.fc7_wd_mult,
                                     init=mx.init.Normal(0.01))
        s = config.loss_s
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(
            embedding, mode='instance', name='fc1n') * s

        fc7 = mx.sym.FullyConnected(data=nembedding,
                                    weight=_weight,
                                    no_bias=True,
                                    num_hidden=config.num_classes,
                                    name='fc7')
        if config.loss_m1 != 1.0 or config.loss_m2 != 0.0 or config.loss_m3 != 0.0:
            if config.loss_m1 == 1.0 and config.loss_m2 == 0.0:
                s_m = s * config.loss_m3
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=config.num_classes,
                                            on_value=s_m,
                                            off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if config.loss_m1 != 1.0:
                    t = t * config.loss_m1
                if config.loss_m2 > 0.0:
                    t = t + config.loss_m2
                body = mx.sym.cos(t)
                if config.loss_m3 > 0.0:
                    body = body - config.loss_m3
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label,
                                            depth=config.num_classes,
                                            on_value=1.0,
                                            off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body

参考文献:
【1】CosFace: Large Margin Cosine Loss for Deep Face Recognition
【2】ArcFace: Additive Angular Margin Loss for Deep Face Recognition

posted @ 2021-06-14 22:38  星辰大海,绿色星球  阅读(559)  评论(0编辑  收藏  举报