mtcnn 的 ohem 思想很简单,就是把每个 batch 样本的 loss 排序,选取前 一定比例 的 loss 作为 该 batch 的 loss;

我们可以这么理解,只选大 loss,小 loss 被弃用,相当于 小 loss 为 0,0 进行反向传播时自然不用再算了;

 

本文重点在解读 ohem 代码

def cls_ohem(cls_prob, label):
    num_keep_radio = 0.7

    # cls_prob shape is [384, 2]
    # label shape is [384], valus are one of [1 0 -1 -2]

    ### mtcnn 中分类问题只用 pos 和 neg 样本,part 和 landmark 样本不参与计算;
    ### 而 送入网络时样本并没有区分开,每个样本输出都有 cls_prob;
    ### ohem 要做的是把 pos 和 neg 对应的样本的 cls_prob 找出来,其他的不需要;

    ones = tf.ones_like(label)
    zeros = tf.zeros_like(label)

    sample_num = cls_prob.get_shape()[0]
    cls_prob_size = tf.size(cls_prob)
    cls_prob_reshape = tf.reshape(cls_prob, [cls_prob_size, -1])    ### 相当于竖着展开了
    ### 展开以后,0 2 4 6... 位置上的是 非人脸的 prob,1 3 5 7... 位置上的是 人脸的 prob

    ### 把人脸和非人脸的 prob 取出来
    raw = tf.range(sample_num) * 2 # 0 2 4 6... 这些位置上是非人脸的 prob,加上 1 个位置就是人脸的 prob
    # 把 label 中人脸的 label 变成 1, 其他变成 0
    label_filter_one = tf.where(tf.less(label, 0), zeros, label)
    cls_prob_pos_index = label_filter_one + raw
    cls_prob_face_noface = tf.squeeze(tf.gather(cls_prob_reshape, cls_prob_pos_index), axis=1)      ### shape [384]

    loss = -tf.log(cls_prob_face_noface + 1e-10)        ### 加上 1e-10 是为了防止 cls_prob_neg_pos 很小,log 后为 无穷

    ### 把 pos 和 neg 的 prob 取出来
    # 把 pos 和 neg 的label 变成 1, 其他变成 0,然后点乘
    label_filter_two = tf.where(tf.less(label, 0), zeros, ones)
    label_filter_two = tf.cast(label_filter_two, tf.float32)
    loss = loss * label_filter_two

    ### pos and neg sample number
    pos_neg_count = tf.reduce_sum(label_filter_two)
    pos_neg_hard = tf.to_int32(pos_neg_count * num_keep_radio)

    loss, _ = tf.nn.top_k(loss, k=pos_neg_hard)

    with tf.Session() as sess:
        # print(sess.run(cls_prob_pos_index))
        # print(sess.run(label_filter_two))
        # print(sess.run(loss))
        print(111, sess.run(cls_prob_face_noface))
        print(222, sess.run(label_filter_two))

    return tf.reduce_mean(loss)

cls_prob = tf.random_uniform([10, 2],0, 1, seed = 100)
# label = np.array([1,0,0,0,0,0,0,0,0,0])
label = np.array([1,0,0,0,-1,0,-2,0,0,0])       # Loss:1.4161593
print(label.shape)
loss = cls_ohem(cls_prob, label)
with tf.Session() as sess:
    print(sess.run(loss))

 

注意,不要在 取 log 之前进行浮点数之间的乘法,对精度影响很大,  【为此我折腾了一上午】

 

其中最难理解的就是下面几句

 ### 把人脸和非人脸的 prob 取出来
    raw = tf.range(sample_num) * 2 # 0 2 4 6... 这些位置上是非人脸的 prob,加上 1 个位置就是人脸的 prob
    # 把 label 中人脸的 label 变成 1, 其他变成 0
    label_filter_one = tf.where(tf.less(label, 0), zeros, label)
    cls_prob_pos_index = label_filter_one + raw
    cls_prob_face_noface = tf.squeeze(tf.gather(cls_prob_reshape, cls_prob_pos_index), axis=1)      ### shape [384]

图解如下

 

 

 

 

参考资料:

https://blog.csdn.net/zhouzongzong/article/details/94716746