Hit Ratio (HR)
简单来说,就是看这个 batch 中每一行的 top-k 预测结果里是否包含 target item(命中数 / target item 总数)。
def compute_hit_ratio_cos_sim(label, pred, k):
    """Compute the batch-level Hit Ratio (HR@k) from a score matrix.

    For every row of ``pred`` the ``k`` highest-scoring columns are selected.
    The ratio is the number of positive entries of ``label`` that fall inside
    those top-k columns, divided by the total number of positive entries of
    ``label``.

    Args:
        label: Tensor of shape [bs, bs]. It is presumably a 0/1 indicator of
            the target item(s) for each row (TODO confirm against callers).
            Any numeric or bool dtype is accepted because it is cast to
            float32.
        pred: Score tensor of shape [bs, bs], e.g. cosine similarities. Only
            its last dimension is used to bound ``k``.
        k: Python int cut-off. If it exceeds the number of columns of
            ``pred``, it is clamped to that number, so every column counts as
            being in the top-k.

    Returns:
        Scalar float32 tensor equal to hits / positives, or 0.0 when
        ``label`` contains no positive entries.
    """
    num_candidates = tf.shape(pred)[-1]
    # tf.math.top_k requires k <= size of the last dimension. When k is at
    # least the number of candidates, every candidate is trivially in the
    # top-k, so clamping gives the correct metric instead of an error.
    k_eff = tf.minimum(tf.cast(k, tf.int32), num_candidates)
    top_k_indices = tf.math.top_k(pred, k=k_eff).indices  # [bs, k_eff]

    # one_hot -> [bs, k_eff, n]. Summing over the k axis gives a 0/1 mask of
    # shape [bs, n]; values never exceed 1 because top_k indices in a row are
    # distinct.
    is_top_k = tf.reduce_sum(tf.one_hot(top_k_indices, num_candidates), axis=1)

    # Cast label before multiplying. one_hot is float32 and TF does not
    # auto-promote dtypes, so an int/bool label would otherwise raise an error.
    label_f = tf.cast(label, tf.float32)
    hits = tf.reduce_sum(is_top_k * label_f)
    positives = tf.reduce_sum(label_f)

    # Exact hits/positives; returns 0 instead of NaN/inf when there are no
    # positives. This avoids a per-element epsilon that would bias the
    # denominator.
    return tf.math.divide_no_nan(hits, positives)