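SimCSE loss implementation in TensorFlow 2
The core of contrastive learning is how the loss is written. This post records a TensorFlow implementation of the SimCSE loss.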
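import tensorflow as tf

def unsupervise_loss(y_pred, alpha=0.05):
    # Rows are assumed to be ordered in pairs (0, 1), (2, 3), ...,
    # so the positive for row i is its neighbour in the pair.
    idxs = tf.range(tf.shape(y_pred)[0])
    y_true = idxs + 1 - idxs % 2 * 2
    # L2-normalize so that the dot product is the cosine similarity.
    y_pred = tf.math.l2_normalize(y_pred, axis=1)
    similarities = tf.matmul(y_pred, y_pred, adjoint_b=True)
    # Mask the diagonal so a row is never matched with itself.
    similarities = similarities - tf.eye(tf.shape(y_pred)[0]) * 1e12
    # alpha acts as the temperature.
    similarities = similarities / alpha
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, similarities, from_logits=True)
    return tf.reduce_mean(loss)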
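To make the label formula and the diagonal mask easier to check by hand, here is a minimal dependency-free sketch of the same computation. The helper names `pair_labels` and `reference_unsupervised_loss` are hypothetical and only serve as a cross-check. The sketch assumes that no row is all zeros and that skipping the row itself is effectively the same as subtracting 1e12 on the diagonal.

```python
import math

def pair_labels(n):
    """Label of row i is its partner in the pair: 0<->1, 2<->3, ..."""
    return [i + 1 - i % 2 * 2 for i in range(n)]

def reference_unsupervised_loss(rows, alpha=0.05):
    """Plain-Python cross-check of the unsupervised loss (hypothetical helper)."""
    # L2-normalize each row (assumes no all-zero rows).
    unit = []
    for r in rows:
        norm = math.sqrt(sum(x * x for x in r))
        unit.append([x / norm for x in r])
    n = len(unit)
    labels = pair_labels(n)
    total = 0.0
    for i in range(n):
        # Temperature-scaled cosine similarity to every other row.
        # The row itself is skipped instead of being masked with -1e12.
        logits = {
            j: sum(a * b for a, b in zip(unit[i], unit[j])) / alpha
            for j in range(n) if j != i
        }
        # Numerically stable log-sum-exp, then cross-entropy for the true label.
        m = max(logits.values())
        log_z = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += log_z - logits[labels[i]]
    return total / n

print(pair_labels(6))  # [1, 0, 3, 2, 5, 4]
```

With a batch such as `[[1, 0], [1, 0], [0, 1], [0, 1]]`, where each pair is identical and the pairs are orthogonal, the loss should be close to zero. Swapping the middle two rows breaks the pairing and the loss becomes large.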
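def supervise_loss(y_pred, alpha=0.05):
    # Rows are assumed to be ordered in triplets: (anchor, positive, negative), ...
    n = tf.shape(y_pred)[0]
    row = tf.range(0, n, 3)  # anchor rows
    col = tf.range(n)
    col = tf.squeeze(tf.where(col % 3 != 0), axis=1)  # positive and negative rows
    # After the column gather, the positive of anchor k sits in column 2 * k.
    y_true = tf.range(0, tf.shape(col)[0], 2)
    # L2-normalize so that the dot product is the cosine similarity.
    y_pred = tf.math.l2_normalize(y_pred, axis=1)
    similarities = tf.matmul(y_pred, y_pred, adjoint_b=True)
    # Keep only anchor rows and positive/negative columns.
    similarities = tf.gather(similarities, row, axis=0)
    similarities = tf.gather(similarities, col, axis=1)
    similarities = similarities / alpha
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, similarities, from_logits=True)
    return tf.reduce_mean(loss)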
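The index arithmetic in the supervised loss is easier to follow on a concrete batch. The sketch below reproduces only the `row`, `col` and `y_true` layout in plain Python. The helper name `supervised_layout` is hypothetical, and the batch is assumed to be ordered as (anchor, positive, negative) triplets.

```python
def supervised_layout(n):
    """Return (rows, cols, labels) for a batch of n = 3 * k embeddings."""
    if n % 3 != 0:
        raise ValueError("batch size must be a multiple of 3")
    rows = list(range(0, n, 3))                 # anchor rows
    cols = [j for j in range(n) if j % 3 != 0]  # positive and negative rows
    labels = list(range(0, len(cols), 2))       # position of each anchor's positive in cols
    return rows, cols, labels

rows, cols, labels = supervised_layout(6)
print(rows, cols, labels)  # [0, 3] [1, 2, 4, 5] [0, 2]

# Each label points at the row right after its anchor, i.e. the positive.
for anchor, label in zip(rows, labels):
    print(anchor, "->", cols[label])
```

For a batch of 6, anchor 0 is matched with row 1 and anchor 3 with row 4, while rows 2 and 5 act as negatives for both anchors.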
Suppose the embedding dimension is 3 and the batch holds 6 embeddings:
y_pred = tf.random.uniform((6, 3))
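Both losses rely on the order of the rows in `y_pred`, so the inputs to the encoder have to be arranged accordingly. As a rough sketch, assuming the encoder preserves input order, the batches could be built as follows. The helper names `build_unsupervised_batch` and `build_supervised_batch` are hypothetical and not part of any library.

```python
def build_unsupervised_batch(sentences):
    """Repeat each sentence twice: [s0, s0, s1, s1, ...].

    The two copies get different embeddings only if the encoder
    applies dropout (or another source of noise) during training.
    """
    return [s for s in sentences for _ in range(2)]

def build_supervised_batch(triplets):
    """Flatten (anchor, positive, negative) triplets, keeping their order."""
    return [s for triplet in triplets for s in triplet]

print(build_unsupervised_batch(["a", "b", "c"]))
print(build_supervised_batch([("a", "a+", "a-"), ("b", "b+", "b-")]))
```

A batch of 6 rows, as in the example above, corresponds to 3 sentences for the unsupervised loss or 2 triplets for the supervised loss.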
References:
E-commerce search recall. https://github.com/muyuuuu/E-commerce-Search-Recall?spm=5176.21852664.0.0.79006ebf02bd2j
SimCSE pytorch. https://github.com/zhengyanzhao1997/NLP-model/tree/main/model/model/Torch_model/SimCSE-Chinese
A source-code walkthrough of the SimCSE loss implementation. https://zhuanlan.zhihu.com/p/377862950
An introduction to SimCSE with a detailed look at the core code: unsupervised text embedding extraction. https://zhuanlan.zhihu.com/p/462763973