
Hit Ratio (HR)

[Recommender Systems] Hit Ratio, an Offline Evaluation Metric for the Recall Stage

Put simply, Hit Ratio checks whether the target item appears among the top k items scored within the batch.
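To make the definition concrete, here is a minimal pure-Python sketch. It assumes exactly one target item per row, and the names `hit_ratio_at_k`, `scores`, and `targets` are made up for illustration; they are not part of the original code.

```python
def hit_ratio_at_k(scores, targets, k):
    """Fraction of rows whose target item is among the k highest scores.

    scores:  list of rows; scores[i][j] is the score of item j for row i
    targets: targets[i] is the index of the target item for row i (assumed: one per row)
    """
    hits = 0
    for row, target in zip(scores, targets):
        # Indices of the k highest-scoring items in this row
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


# Toy 3x3 batch: row i's target is item i (the diagonal)
scores = [
    [0.9, 0.1, 0.3],  # target 0 ranks 1st
    [0.8, 0.6, 0.5],  # target 1 ranks 2nd
    [0.1, 0.4, 0.7],  # target 2 ranks 1st
]
targets = [0, 1, 2]

hr_at_1 = hit_ratio_at_k(scores, targets, 1)  # 2 of 3 rows hit
hr_at_2 = hit_ratio_at_k(scores, targets, 2)  # all 3 rows hit
```

With k = 1, the second row misses because its target only ranks second; raising k to 2 turns that row into a hit.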

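import tensorflow as tf


def compute_hit_ratio_cos_sim(label, pred, k):
    """
    label: [bs, bs], marks the target items of each row (bs = batch size)
    pred: [bs, bs], pairwise scores (cosine similarity, per the function name)
    """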
    print("pred shape:", tf.shape(pred)[0], pred.get_shape()[0], pred.shape.as_list())
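    # Indices of the top-k elements of each row, shape [bs, k].
    # If k is smaller than bs, use k; otherwise fall back to top 15
    # (the fallback itself requires bs >= 15).
    top_k_indices = tf.cond(
        tf.less(tf.constant(k), tf.shape(pred)[0]),
        lambda: tf.math.top_k(pred, k).indices,
        lambda: tf.math.top_k(pred, 15).indices,
    )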
    # top_k_indices = tf.math.top_k(pred, k).indices

    # [bs, k, bs] --> [bs, bs]: one-hot encode the indices, then sum over the k axis.
    # Each row ends up with k ones and zeros elsewhere,
    # e.g. [[0, 1, 1, 0, 1, ...], [1, 1, 0, 0, 1, ...], ...]
    is_top_k = tf.reduce_sum(tf.one_hot(top_k_indices, tf.shape(pred)[1]), axis=1)
    # print("is_top_k vec dim: ", is_top_k.get_shape())

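    # Compute the hit ratio: targets that land in the top k, divided by all targets.
    # The epsilon is added once, outside the sum, so it does not scale with bs * bs.
    label = tf.cast(label, tf.float32)
    hit_ratio = tf.reduce_sum(is_top_k * label) / (tf.reduce_sum(label) + 1e-5)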
    return hit_ratio
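Written as a formula, the function above computes roughly the following, apart from a small smoothing term in the denominator that guards against division by zero. The symbols are assumptions introduced here for illustration: $Y$ stands for `label`, $P$ for `pred`, and $M$ for the `is_top_k` mask.

```latex
\mathrm{HR@}k \;=\; \frac{\sum_{i}\sum_{j} M_{ij}\, Y_{ij}}{\sum_{i}\sum_{j} Y_{ij}},
\qquad
M_{ij} =
\begin{cases}
1 & \text{if } j \in \operatorname{TopK}(P_{i,:}) \\
0 & \text{otherwise}
\end{cases}
```

When each row has exactly one target, the denominator equals the batch size, and the value is simply the fraction of rows whose target is in the top k.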
posted @ 2022-05-20 11:09  戴墨镜的长颈鹿