LLMs: FlagEmbedding BiEncoderModel source code walkthrough, and how to evaluate and choose an embedding model
One common NLP task is efficient retrieval: quickly finding the passages or documents in a large corpus that are relevant to a query. The user types a query, and we need to return the semantically closest, best-matching answer from the corpus. Other downstream tasks such as text classification and sentiment analysis also start by computing text embeddings. All of this can be done with a bi-encoder ("two-tower") architecture. The core idea is simple: use two encoder towers (in practice often sharing weights) to compute the query embedding and the answer embedding separately, measure the distance between the two embeddings (cosine similarity or dot product both work), and return the top-k closest embeddings as the best answers. Storing the vectors and finding the top-k can be delegated to a dedicated vector index or database such as FAISS.
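To make the two-tower flow concrete before diving into the source, here is a minimal retrieval sketch. The model name (BAAI/bge-small-en-v1.5), CLS pooling and the three-sentence corpus are assumptions for illustration, not something prescribed by the FlagEmbedding source:

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "BAAI/bge-small-en-v1.5"   # assumed model; any BERT-style embedding model works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name).eval()

corpus = ["How do I become an actor in film industry?",
          "What is the story of Moses and Ramesses?",
          "Does caste system affect economic growth of India?"]
query = "How does one become an actor in the Telugu Film Industry?"

@torch.no_grad()
def embed(texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    out = encoder(**batch)
    cls = out.last_hidden_state[:, 0]                   # CLS pooling
    return torch.nn.functional.normalize(cls, dim=-1)   # normalized, so dot product == cosine

q_emb = embed([query])            # [1, D]
p_emb = embed(corpus)             # [3, D]
scores = q_emb @ p_emb.T          # cosine similarities, shape [1, 3]
print(scores.topk(k=2, dim=-1))   # indices/values of the best-matching passages

In a real system the corpus embeddings would be computed once and stored in a vector index such as FAISS, and only the query would be embedded at request time.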
1. Turning a sentence into an embedding: two pooling methods are provided here, averaging the token embeddings (mean pooling) or taking the first [CLS] token's embedding to represent the whole sentence.
def sentence_embedding(self, hidden_state, mask):
    if self.sentence_pooling_method == 'mean':
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(axis=1, keepdim=True).float()
        return s / d
    elif self.sentence_pooling_method == 'cls':
        return hidden_state[:, 0]
Besides these two, two more pooling options are worth considering:
# max pooling
max_embedding = outputs.last_hidden_state.max(dim=1).values

# concatenate several representations
concatenated_embedding = torch.cat([cls_embedding, mean_embedding, max_embedding], dim=1)
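Two caveats with these variants (my own notes, not from the source): the naive max above also runs over padding positions, and the concatenated vector triples the embedding dimension, so any downstream index must be built with the larger size. A mask-aware max pooling sketch:

import torch

def max_pooling(hidden_state, mask):
    # hidden_state: [B, L, D], mask: [B, L]; padded positions are filled with -inf so they never win the max
    masked = hidden_state.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
    return masked.max(dim=1).values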
2. Turning the model inputs into embedding vectors:
def encode(self, features):
    if features is None:
        return None
    psg_out = self.model(**features, return_dict=True)  # run the encoder's forward pass to get token embeddings
    p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])  # pool into a sentence embedding
    if self.normlized:  # normalize, so the next step can use cosine or dot product interchangeably
        p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
    return p_reps.contiguous()
3. Computing the similarity: it is just a matrix product of the query and passage matrices, i.e. dot products at heart.
def compute_similarity(self, q_reps, p_reps):
    if len(p_reps.size()) == 2:
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
    return torch.matmul(q_reps, p_reps.transpose(-2, -1))
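To make the shapes concrete (the sizes are illustrative and match the walkthrough below: 2 queries, 6 passages, hidden size 768):

import torch

q_reps = torch.randn(2, 768)                             # 2 query embeddings
p_reps = torch.randn(6, 768)                             # 6 passage embeddings
scores = torch.matmul(q_reps, p_reps.transpose(0, 1))    # [2, 6]: one row of similarities per query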
4. The loss is even simpler: it is just cross entropy.
def compute_loss(self, scores, target):
    return self.cross_entropy(scores, target)
5. The heart of the model is the forward method:
def forward(self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
            teacher_score: Tensor = None):
    q_reps = self.encode(query)    # query and passage are encoded separately: this is the "bi" in bi-encoder / two-tower
    p_reps = self.encode(passage)

    if self.training:
        if self.negatives_cross_device and self.use_inbatch_neg:
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)

        group_size = p_reps.size(0) // q_reps.size(0)
        if self.use_inbatch_neg:  # score every query against every passage in the batch
            scores = self.compute_similarity(q_reps, p_reps) / self.temperature  # B x (B*G)
            scores = scores.view(q_reps.size(0), -1)

            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
            target = target * group_size
            loss = self.compute_loss(scores, target)  # cross entropy against the positive's index
        else:
            scores = self.compute_similarity(q_reps[:, None, :,],
                                             p_reps.view(q_reps.size(0), group_size, -1)).squeeze(1) / self.temperature  # B x G

            scores = scores.view(q_reps.size(0), -1)
            target = torch.zeros(scores.size(0), device=scores.device, dtype=torch.long)
            loss = self.compute_loss(scores, target)
    else:
        scores = self.compute_similarity(q_reps, p_reps)
        loss = None
    return EncoderOutput(
        loss=loss,
        scores=scores,
        q_reps=q_reps,
        p_reps=p_reps,
    )
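A note on the two training branches: with use_inbatch_neg, every query is scored against every passage in the (optionally cross-device gathered) batch, so the positives and negatives belonging to the other queries act as extra in-batch negatives for free, and the label for query i is i * group_size, the position of its own positive. Without it, each query is only compared with its own group of group_size passages, whose positive always sits at index 0, hence the all-zero target.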
Reading only the code feels abstract, so here is the whole flow in detail:
(1) Suppose the training samples look like this:
data = [
    {
        "query": "How does one become an actor in the Telugu Film Industry?",
        "pos": ["How do I become an actor in film industry?"],
        "neg": ["What is the story of Moses and Ramesses?",
                "Does caste system affect economic growth of India?"]
    },
    {
        "query": "Why do some computer programmers develop amazing software or new concepts, while some are stuck with basic programming work?",
        "pos": ["Why do some computer programmers develops amazing softwares or new concepts, while some are stuck with basics programming works?"],
        "neg": ["When visiting a friend, do you ever think about what would happen if you did something wildly inappropriate like punch them or destroy their furniture?",
                "What is the difference between a compliment and flirting?"]
    }
]
(2) The query and the answers are tokenized separately into token ids; the answers are called passages. Within the passage batch, each query's positive comes first, immediately followed by that query's negatives. (Note: the token ids below are only illustrative and not necessarily correct; they are there to show the flow. A hedged tokenizer sketch follows the example.)
query = {
    'input_ids': tensor([
        [101, 2129, 2515, 2028, 2468, 2019, 4449, 1999, 1996, 10165, 2143, 3068, 102],            # query 1
        [101, 2339, 2079, 2070, 3274, 13193, 3285, 12460, 13191, 3021, 1997, 3749, 2135, 102]]),   # query 2
    'attention_mask': ...  # same shape as input_ids, all ones here (padding not shown)
}
passage = {
    'input_ids': tensor([
        [101, 2129, 2079, 1045, 2468, 2019, 4449, 1999, 10165, 2143, 3068, 102],                   # query 1's positive
        [101, 2054, 2003, 1996, 2466, 1997, 7929, 1998, 10500, 1029, 102],                         # query 1's negative 1
        [101, 2515, 9397, 3600, 7462, 3599, 2964, 4100, 3600, 2630, 1997, 2290, 1029, 102],        # query 1's negative 2
        [101, 2339, 2079, 2070, 3274, 13193, 3285, 12460, 13191, 3021, 1997, 3749, 2135, 102],     # query 2's positive
        [101, 2043, 6188, 1037, 2767, 1010, 2079, 2017, 2412, 2228, 2055, 2054, 2052, 2490, 2017, 2106, 1037, 10723, 21446, 2066, 7059, 2068, 2030, 5620, 2037, 4192, 1029, 102],  # query 2's negative 1
        [101, 2054, 2003, 1996, 4487, 2090, 1037, 9994, 1998, 18095, 1029, 102]]),                 # query 2's negative 2
    'attention_mask': ...  # same shape as input_ids, all ones here (padding not shown)
}
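For reference, ids of this shape could be produced roughly like this (a hedged sketch; the real FlagEmbedding collator differs in details such as max length, padding strategy and query instruction prefixes):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")   # tokenizer choice is an assumption

queries = [d["query"] for d in data]
passages = []
for d in data:
    passages.extend(d["pos"] + d["neg"])   # grouped per query: positive first, then its negatives

query = tokenizer(queries, padding=True, truncation=True, return_tensors="pt")
passage = tokenizer(passages, padding=True, truncation=True, return_tensors="pt")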
(3) This is where the two towers come in: the query and the passages are each run through the encoder to get one embedding per sentence.
q_reps = self.encode(query)     # shape [2, embedding_dim]: embeddings of the two queries
p_reps = self.encode(passage)   # shape [6, embedding_dim]: embeddings of the six passages
(4) Compute the similarity between the queries and the passages:
scores = self.compute_similarity(q_reps, p_reps) / self.temperature
This step is the crux. Suppose the result looks like this (again, the numbers are illustrative and only serve to explain the flow):
scores = tensor([
    [0.8, 0.1, 0.3, 0.2, 0.4, 0.7],   # similarity of query 1 to all six passages
    [0.5, 0.6, 0.4, 0.9, 0.2, 0.3]    # similarity of query 2 to all six passages
])
These are the similarities between each query and all passages, but which passages are the positives and which are the negatives? They need to be told apart, and that is done as follows:
target = torch.arange(2, device=scores.device, dtype=torch.long)   # [0, 1]
target = target * 3                                                # group_size = 6 // 2 = 3, so [0, 3]
This way, target holds the index of each query's positive passage: target[0] = 0 means the correct passage for query 1 is passage[0], and target[1] = 3 means the correct passage for query 2 is passage[3].
(5) With all the preparation above done, the last step is the cross-entropy loss: scores should line up with target. Since target points at the positive answer, the score at the positive's index should be driven as high as possible and the scores at the negative indices as low as possible, which is exactly how the loss separates positive from negative answers.
loss = cross_entropy(scores, target)
One last question: target only holds the positions of the positive passages, and the cross entropy with scores essentially uses target to pick out and maximize the corresponding column of scores. Seen that way, it looks as if the negatives were never used?
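They are used, though: the softmax inside cross entropy normalizes each row over all columns, so every negative score sits in the denominator and is pushed down as the positive score is pushed up. A quick check on the illustrative scores from step (4) (my own snippet, not from the FlagEmbedding source):

import torch
import torch.nn.functional as F

scores = torch.tensor([[0.8, 0.1, 0.3, 0.2, 0.4, 0.7],
                       [0.5, 0.6, 0.4, 0.9, 0.2, 0.3]], requires_grad=True)
target = torch.tensor([0, 3])          # index of each query's positive passage

loss = F.cross_entropy(scores, target)
loss.backward()
print(scores.grad)
# The gradient is negative on the two target columns (their scores get pushed up)
# and positive on every other column (the negatives' scores get pushed down),
# so the negatives do shape the loss through the softmax denominator.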
Schematic of the whole flow:
Why use BERT to produce the sentence embedding? BERT attends to the full context of the input, which makes it a natural fit for representing the meaning of a whole sentence.
(6) There are plenty of BERT-style embedding models on Hugging Face, so which one should you pick? The MTEB leaderboard at https://huggingface.co/spaces/mteb/leaderboard provides a ranking worth consulting:
In general, the criteria for choosing a model are:
- Max Tokens: the maximum length your queries and passages can have
- Embedding Dimensions: how rich the representation is, whether broad and general-purpose or narrow and specialized
- Memory Usage: whether your own hardware can host the model
Going a step further, take 10-20 of your own samples, embed them, project to 2D with t-SNE, and check whether the correct answers cluster together; if they do, the model is a reasonable choice. A sketch of this check follows.
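A minimal version of that sanity check (assumptions: the embed() helper sketched at the top of this post, scikit-learn and matplotlib installed, and a made-up mini test set with two topic groups):

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

texts = [
    "How do I become an actor in film industry?",
    "What should I study to get into acting?",
    "Tips for auditioning for Telugu movies",
    "Does caste system affect economic growth of India?",
    "How does GDP growth relate to social structure in India?",
    "Economic impact of social hierarchy in India",
]
labels = [0, 0, 0, 1, 1, 1]                # which topic each text belongs to

embs = embed(texts).numpy()                # reuse the embed() helper from the retrieval sketch
proj = TSNE(n_components=2, perplexity=3, random_state=0).fit_transform(embs)

plt.scatter(proj[:, 0], proj[:, 1], c=labels)
plt.title("t-SNE of candidate embedding model")
plt.show()

If the two topics form clearly separated clusters, the model is a reasonable candidate for your data.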
References:
1. https://github.com/FlagOpen/FlagEmbedding
2. https://www.bilibili.com/video/BV1sQ4y137Ft/?spm_id_from=pageDriver&vd_source=241a5bcb1c13e6828e519dd1f78f35b2
3. https://huggingface.co/BAAI/bge-m3
4. https://www.53ai.com/news/qianyanjishu/816.html
5. https://www.bilibili.com/video/BV1Hk4y1X7aG/?spm_id_from=333.999.0.0&vd_source=241a5bcb1c13e6828e519dd1f78f35b2 (embedding methods)
6. https://github.com/blackinkkkxi/RAG_langchain