late chunking 源码分析-
import bisect
import logging
from typing import Dict, List, Optional, Tuple, Union

from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.schema import Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from transformers import AutoTokenizer

# Set the logging level to WARNING to suppress INFO and DEBUG messages
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)

CHUNKING_STRATEGIES = ['semantic', 'fixed', 'sentences']


class Chunker:
    """Split a text into chunks described as token-index spans.

    Every chunking method tokenizes the full text with the supplied
    tokenizer and returns a list of ``(start, end)`` pairs, where ``start``
    and ``end`` are positions in that token sequence and ``end`` is
    exclusive. The supported strategies are listed in
    ``CHUNKING_STRATEGIES``.
    """

    def __init__(
        self,
        chunking_strategy: str,
    ):
        """Store the default strategy used by :meth:`chunk`.

        Raises:
            ValueError: if ``chunking_strategy`` is not one of
                ``CHUNKING_STRATEGIES``.
        """
        if chunking_strategy not in CHUNKING_STRATEGIES:
            # Build one message string; passing the strategy as a second
            # positional argument made the error render as a tuple.
            raise ValueError(
                f"Unsupported chunking strategy: {chunking_strategy}"
            )
        self.chunking_strategy = chunking_strategy
        # Both are filled in lazily by _setup_semantic_chunking().
        self.embed_model = None
        self.embedding_model_name = None

    def _setup_semantic_chunking(self, embedding_model_name):
        """Create the embedding model and the semantic splitter.

        A truthy ``embedding_model_name`` replaces the stored model name;
        otherwise the previously stored name (possibly ``None``) is used.
        """
        if embedding_model_name:
            self.embedding_model_name = embedding_model_name

        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model_name,
            trust_remote_code=True,
            embed_batch_size=1,
        )
        self.splitter = SemanticSplitterNodeParser(
            embed_model=self.embed_model,
            show_progress=False,
        )

    def chunk_semantically(
        self,
        text: str,
        tokenizer: 'AutoTokenizer',
        embedding_model_name: Optional[str] = None,
    ) -> List[Tuple[int, int]]:
        """Chunk ``text`` with the semantic splitter and map to token spans.

        The splitter yields character ranges; each range is converted to a
        token span by binary search over the tokenizer's offset mapping.

        Args:
            text: the document to split.
            tokenizer: tokenizer used to obtain the offset mapping.
            embedding_model_name: model name used on first call to build
                the splitter; ignored once the splitter exists.

        Returns:
            A list of ``(start_token, end_token)`` pairs.
        """
        if self.embed_model is None:
            self._setup_semantic_chunking(embedding_model_name)

        # Character ranges of the semantic nodes.
        nodes = [
            (node.start_char_idx, node.end_char_idx)
            for node in self.splitter.get_nodes_from_documents(
                [Document(text=text)], show_progress=False
            )
        ]

        # Tokenize the entire text.
        tokens = tokenizer.encode_plus(
            text,
            return_offsets_mapping=True,
            add_special_tokens=False,
            padding=True,
            truncation=True,
        )
        token_offsets = tokens.offset_mapping

        # Build the bisect keys once instead of once per node.
        token_starts = [offset[0] for offset in token_offsets]
        token_ends = [offset[1] for offset in token_offsets]
        n_tokens = len(token_offsets)

        chunk_spans = []
        for char_start, char_end in nodes:
            # Convert char indices to token indices.
            start_chunk_index = bisect.bisect_left(token_starts, char_start)
            end_chunk_index = bisect.bisect_right(token_ends, char_end)

            # Stop at the first node that starts beyond the tokenized text
            # (the tokenizer call above may truncate).
            if start_chunk_index >= n_tokens or end_chunk_index > n_tokens:
                break
            chunk_spans.append((start_chunk_index, end_chunk_index))

        return chunk_spans

    def chunk_by_tokens(
        self,
        text: str,
        chunk_size: int,
        tokenizer: 'AutoTokenizer',
    ) -> List[Tuple[int, int]]:
        """Split ``text`` into consecutive spans of ``chunk_size`` tokens.

        The last span is shorter when the token count is not a multiple of
        ``chunk_size``.

        Returns:
            A list of ``(start_token, end_token)`` pairs.
        """
        tokens = tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        n_tokens = len(tokens.offset_mapping)

        return [
            (start, min(start + chunk_size, n_tokens))
            for start in range(0, n_tokens, chunk_size)
        ]

    def chunk_by_sentences(
        self,
        text: str,
        n_sentences: int,
        tokenizer: 'AutoTokenizer',
    ) -> List[Tuple[int, int]]:
        """Split ``text`` into spans of ``n_sentences`` sentences each.

        A sentence ends at a ``.``, ``!`` or ``?`` token that is either the
        last token or is not immediately followed (character-wise) by the
        next token. Tokens left over after the last full group form a final
        span when there is more than one of them. If ``n_sentences`` never
        matches the running count (for example ``None``), only that final
        span is produced.

        Returns:
            A list of ``(start_token, end_token)`` pairs.
        """
        tokens = tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        # Fetch the token strings once; the original re-requested them on
        # every loop iteration.
        token_strings = tokens.tokens(0)
        n_tokens = len(token_strings)

        chunk_spans = []
        chunk_start = 0
        count_chunks = 0
        for i, token in enumerate(token_strings):
            if token not in ('.', '!', '?'):
                continue
            is_last = n_tokens == i + 1
            # Short-circuit keeps token_to_chars(i + 1) from running past
            # the end of the sequence.
            if is_last or (
                tokens.token_to_chars(i).end
                != tokens.token_to_chars(i + 1).start
            ):
                count_chunks += 1
                if count_chunks == n_sentences:
                    chunk_spans.append((chunk_start, i + 1))
                    chunk_start = i + 1
                    count_chunks = 0

        if n_tokens - chunk_start > 1:
            chunk_spans.append((chunk_start, n_tokens))
        return chunk_spans

    def chunk(
        self,
        text: str,
        tokenizer: 'AutoTokenizer',
        chunking_strategy: Optional[str] = None,
        chunk_size: Optional[int] = None,
        n_sentences: Optional[int] = None,
        embedding_model_name: Optional[str] = None,
    ):
        """Dispatch to the chunking method selected by the strategy.

        Args:
            text: the document to split.
            tokenizer: tokenizer passed through to the chunking method.
            chunking_strategy: overrides the instance default when truthy.
            chunk_size: tokens per chunk, required for ``"fixed"``.
            n_sentences: sentences per chunk, used by ``"sentences"``.
            embedding_model_name: model name used by ``"semantic"``.

        Returns:
            The list of token spans produced by the selected method.

        Raises:
            ValueError: for an unknown strategy, or when the ``"fixed"``
                strategy is used with a missing or too-small ``chunk_size``.
        """
        chunking_strategy = chunking_strategy or self.chunking_strategy
        if chunking_strategy == "semantic":
            return self.chunk_semantically(
                text,
                embedding_model_name=embedding_model_name,
                tokenizer=tokenizer,
            )
        elif chunking_strategy == "fixed":
            # A missing chunk_size used to surface as a TypeError from
            # comparing None with an int.
            if chunk_size is None or chunk_size < 4:
                raise ValueError("Chunk size must be >= 4.")
            return self.chunk_by_tokens(text, chunk_size, tokenizer)
        elif chunking_strategy == "sentences":
            return self.chunk_by_sentences(text, n_sentences, tokenizer)
        else:
            raise ValueError("Unsupported chunking strategy")
==》具体使用示例:Chunked Pooling This notebook explains how the chunked pooling can be implemented. First you need to install the requirements: !pip install -r requirements.txt Then we load a model which we want to use for the embedding. We choose jinaai/jina-embeddings-v2-base-en but any other model which supports mean pooling is possible. However, models with a large maximum context-length are preferred. from transformers import AutoModel from transformers import AutoTokenizer from chunked_pooling import chunked_pooling, chunk_by_sentences # load model and tokenizer tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True) model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True) /home/michael/workspace/chunked-pooling/venv/lib/python3.10/site-packages/tqdm/ TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See from .autonotebook import tqdm as notebook_tqdm Now we define the text which we want to encode and split it into chunks. The chunk_by_sentences function also returns the span annotations. Those specify the number of tokens per chunk which is needed for the chunked pooling. input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area." # determine chunks chunks, span_annotations = chunk_by_sentences(input_text, tokenizer) print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"') Chunks: - "Berlin is the capital and largest city of Germany, both by area and by population." - " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits." 
- " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area." Now we encode the chunks with the traditional and the context-sensitive chunked pooling method: # chunk before embeddings_traditional_chunking = model.encode(chunks) # chunk afterwards (context-sensitive chunked pooling) inputs = tokenizer(input_text, return_tensors='pt') model_output = model(**inputs) embeddings = chunked_pooling(model_output, [span_annotations])[0] Finally, we compare the similarity of the word "Berlin" with the chunks. The similarity should be higher for the context-sensitive chunked pooling method: import numpy as np cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)) berlin_embedding = model.encode('Berlin') for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking): print(f'similarity_new("Berlin", "{chunk}"):', cos_sim(berlin_embedding, new_embedding)) print(f'similarity_trad("Berlin", "{chunk}"):', cos_sim(berlin_embedding, trad_embeddings))
similarity_new("Berlin", "Berlin is the capital and largest city of Germany, both by area and by population."): 0.849546 similarity_trad("Berlin", "Berlin is the capital and largest city of Germany, both by area and by population."): 0.84862185 similarity_new("Berlin", " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits."): 0.82489026 similarity_trad("Berlin", " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits."): 0.7084338 similarity_new("Berlin", " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."): 0.84980094 similarity_trad("Berlin", " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."): 0.7534553
从提供的内容来看,"late chunking"是利用长上下文嵌入模型对长度超出模型容量的文本进行分块的一种方法,目的是在对文本进行分块之后保持尽量多的上下文信息,从而在后续的上下文敏感处理过程(如嵌入生成)中提高文本表示的准确性。
- 使用特定的分块策略(如句子分块法)对文本进行初步分块,但尽量保留原始的上下文,比如将文本按句子划分,而不是简单地按固定长度分割。
上下文敏感池化(Chunked Pooling):
- 初步分块后,使用上下文敏感的嵌入模型对整个文本进行编码,然后根据提前定义好的分块信息(特别是结合上下文的句子边界信息)进行专注于上下文的嵌入池化。这种方法强调在不丢失长段信息的情况下进行上下文感知的嵌入计算。
- 在对文本进行分块处理后,并不立即生成嵌入向量,而是在对整个文本进行完整嵌入处理后,根据分块策略重新整合为各个“分块”的嵌入。这与传统的分块再嵌入的方法相对比,能够保留更多的全局上下文信息。
- 将传统分块方法生成的嵌入与上下文敏感方法生成的嵌入进行对比,比如通过计算相似性比较每一个分块与特定关键词(如"Berlin")之间的差异。这种对比能揭示上下文敏感方法在保留语义一致性方面的优越性。
- 使用了 Jina 的嵌入模型,因为其支持对长文本的上下文敏感嵌入生成。
- 通过对文本分块后进行完整的文本输入和编码,最后利用分块策略进行上下文感知的结果整合。
“late chunking”策略可以在多个高级模型中实现,并不限于Jina的嵌入模型。比如,GPT-3.5/4、Longformer、RoBERTa、T5等模型都具有较强的上下文理解能力,能够处理长文本并生成上下文感知的嵌入。在选择模型时要考虑文本长度、应用需求和可
RAG(检索增强生成)是最著名的应用之一,它需要将文档集合拆分成较小的文本块。这些块通常存储在向量数据库中,向量表示由文本嵌入模型创建。在运行时,相同的嵌入模型将查询文本编码为向量表示,用于识别相关的存储文本块。然后将它们传递给大型语言模型 (LLM),该模型根据检索到的文本合成对查询的响应。
这种简单的 RAG 方法并非没有挑战。长距离上下文依赖关系(即当相关信息分散在多个块中并且将文本片段脱离上下文会使它们变得无用)在这种方法中处理得特别糟糕。 在上图中,我们可以看到一篇被分成多个句子块的维基百科文章。我们可以看到,像“its”和“the city”这样的短语引用了仅在第一句中提到的“Berlin”,这使得嵌入模型很难将其链接到相应的实体,从而难以产生高质量的嵌入表示。
例如,如果我们将维基百科文章拆分成句子长度的片段,如上例所示,RAG 系统可能无法回答“柏林的人口是多少?”这样的查询。城市名称和人口永远不会同时出现在单个片段中,并且缺乏任何更大的文档上下文。呈现其中一个片段的 LLM 无法解析“它”或“这个城市”之类的照应性引用。
一种解决思路是使用长上下文嵌入模型(例如 jina-embeddings-v2-base-en)。这些模型支持更长的输入文本,例如 jina-embeddings-v2-base-en 支持 8192 个标记,约相当于十页标准文本。这种大小的文本段不太可能具有只能通过更大的上下文来解决的上下文依赖关系。然而,我们仍然需要更小的文本块的向量表示,部分原因是 LLM 的输入大小有限,但主要是因为短嵌入向量的信息容量有限。
文本 | 相似度(传统分块) | 相似度(后期分块) |
柏林是德国的首都和最大城市,无论从面积还是人口上看都是如此。 | 0.84862185 | 0.849546 |
按照市区人口计算,该市居民超过 385 万,是欧盟人口最多的城市。 | 0.7084338 | 0.82489026 |
该市也是德国的一个州,是该国面积第三小的州。 | 0.7534553 | 0.84980094 |
为了验证这种方法除了几个简单的示例之外的有效性,我们使用BeIR中的一些检索基准对其进行了测试。这些检索任务包括一个查询集、一个文本文档语料库和一个 QRels 文件,该文件存储了与每个查询相关的文档 ID 的信息。为了识别查询的相关文档,可以对文档进行分块,将它们编码为嵌入索引,并确定每个查询嵌入的最相似块 (kNN)。由于每个块对应一个文档,因此可以将块的 kNN 排名转换为文档的 kNN 排名(对于在排名中多次出现的文档,仅保留第一次出现的文档)。之后,可以将结果排名与对应于真实 QRels 文件的排名进行比较,并计算 nDCG@10 等检索指标。我们使用传统分块和我们新颖的后期分块方法对各种 BeIR 数据集进行了此评估。为了将文本拆分成块,我们选择了一种简单的方法,即将文本切分为 256 个标记的字符串。传统和后期分块测试都使用了jina-embeddings-v2-small-en模型。
数据集 | 平均文档长度(字符) | 传统分块(nDCG@10) | 后期分块 (nDCG@10) | 无分块 (nDCG@10) |
SciFact | 1498.4 | 64.20% | 66.10% | 63.89% |
TRECCOVID | 1116.7 | 63.36% | 64.70% | 65.18% |
FiQA2018 | 767.2 | 33.25% | 33.84% | 33.43% |
NFCorpus | 1589.8 | 23.46% | 29.98% | 30.40% |
Quora | 62.2 | 87.19% | 87.19% | 87.19% |
在所有情况下,后期分块都提高了分数。在某些情况下,它的表现也优于将整个文档编码为单个嵌入,而对于其他数据集,没有分块表现最好。然而,这只有在不需要对块进行排序时才有意义。我们还可以看到,文档的平均长度与通过后期分块获得的 nDCG 分数的更大改进相关。
要重现评估,您可以先用 pip install . 安装依赖项,然后运行以下命令:
python3 run_chunked_eval.py --task-name {TASK_NAME}
感谢 Isabelle Mohr(@violenil)贡献一些代码,以及 Scott Martens(@scott-martens)审阅 README。
您可以在MTEB 存储库中找到有关评估任务的更多信息,并可以在我们的论文中找到有关长输入文本模型训练的详细信息:“Jina 嵌入 2:用于长文档的 8192 个标记通用文本嵌入”。
如果您发现 Late Chunking 在您的研究中很有用,可以引用论文Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models:
title={Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models},
author={G{\"u}nther, Michael and Mohr, Isabelle and Williams, Daniel J and Wang, Bo and Xiao, Han},
journal={arXiv preprint arXiv:2409.04701},
# imports
import os
import json

import numpy as np
import spacy
import torch
import transformers
from spacy.language import Language
from spacy.tokens import Doc
from transformers import AutoModel
from transformers import AutoTokenizer


def sentence_chunker(document, batch_size=10000):
    """Split a document into sentence chunks using spaCy's sentencizer.

    The document is processed in slices of ``batch_size`` characters and
    the resulting spaCy docs are merged before sentences are read out.

    Args:
        document: the text to split.
        batch_size: number of characters handed to spaCy per slice.

    Returns:
        ``(chunks, span_annotations)``: the sentence texts, and for each
        sentence its ``(sent.start, sent.end)`` pair from the merged doc.
        Both lists are empty for an empty document.
    """
    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer", config={"punct_chars": None})

    # Only the batched slices are processed. The original also ran
    # nlp(document) on the full text first and threw the result away,
    # which defeated the point of batching.
    docs = [
        nlp(document[i : i + batch_size])
        for i in range(0, len(document), batch_size)
    ]
    if not docs:
        return [], []
    doc = Doc.from_docs(docs)

    # NOTE(review): sent.start / sent.end index the spaCy doc's tokens,
    # while late_chunking() slices embeddings produced from the Hugging
    # Face tokenizer. Confirm the two tokenizations are meant to be used
    # interchangeably here.
    span_annotations = []
    chunks = []
    for sent in doc.sents:
        span_annotations.append((sent.start, sent.end))
        chunks.append(sent.text)

    return chunks, span_annotations


def document_to_token_embeddings(model, tokenizer, document, batch_size=4096):
    """Return token-level embeddings for ``document``.

    The tokenized document is fed to the model in windows of ``batch_size``
    tokens and the per-window ``last_hidden_state`` tensors are
    concatenated along dimension 1.

    Args:
        model: model of jinaai/jina-embeddings-v2-base-en.
        tokenizer: the matching tokenizer.
        document: the text to embed.
        batch_size: tokens per forward pass; must not exceed 8192.

    Returns:
        The concatenated ``last_hidden_state`` tensor.

    Raises:
        ValueError: if ``batch_size`` is larger than 8192.
    """
    if batch_size > 8192:
        raise ValueError("Batch size is too large. Please use a batch size of 8192 or less.")

    tokenized_document = tokenizer(document, return_tensors="pt")
    tokens = tokenized_document.tokens()

    # Batch in sizes of batch_size
    outputs = []
    for start in range(0, len(tokens), batch_size):
        end = min(start + batch_size, len(tokens))

        # subset huggingface tokenizer outputs to start : end
        batch_inputs = {k: v[:, start:end] for k, v in tokenized_document.items()}

        with torch.no_grad():
            model_output = model(**batch_inputs)

        outputs.append(model_output.last_hidden_state)

    # Restored call: the source text had lost "torch.cat(outputs".
    model_output = torch.cat(outputs, dim=1)
    return model_output


def late_chunking(token_embeddings, span_annotation, max_length=None):
    """Mean-pool token embeddings over each annotated span.

    Args:
        token_embeddings: iterable of per-document token embeddings; each
            item is sliced as ``embeddings[start:end]``.
        span_annotation: iterable, parallel to ``token_embeddings``, of
            lists of ``(start, end)`` token spans.
        max_length: when given, spans starting at or after
            ``max_length - 1`` are dropped and the remaining span ends are
            clipped to ``max_length - 1``.

    Returns:
        One list per document holding a NumPy vector per non-empty span.
    """
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if max_length is not None:
            # remove annotations which go beyond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]

        # Empty spans are skipped to avoid dividing by zero.
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        print("pooled_embeddings:", pooled_embeddings)

        # detach()/cpu() keep the conversion valid for tensors that still
        # track gradients or live on another device.
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs


def cos_sim(x, y):
    """Return the cosine similarity of vectors ``x`` and ``y``."""
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


def _embed_query(query):
    """Embed ``query`` by mean-pooling the model's last hidden state."""
    with torch.no_grad():
        model_output = model(**tokenizer(query, return_tensors="pt"))
    return model_output.last_hidden_state.mean(1).detach().cpu().numpy().flatten()


# TODO: the input file path was lost when this script was pasted into the
# page; set it to the document that should be chunked.
DOCUMENT_PATH = ""

with open(DOCUMENT_PATH, "r", encoding="utf-8") as f:
    document = f.read()

print(f"{document[:150]}...")

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)

chunks, span_annotations = sentence_chunker(document)
token_embeddings = document_to_token_embeddings(model, tokenizer, document)
chunk_embeddings = late_chunking(token_embeddings, [span_annotations])[0]


def late_chunking_query_by_milvus(query, top_k = 3):
    """Return the ``content`` field of the ``top_k`` Milvus search hits.

    NOTE(review): ``client`` and ``collection`` are not defined anywhere in
    the visible script, and the call target had been lost from the source
    text; ``client.search`` is restored on the assumption that ``client``
    is a pymilvus ``MilvusClient``. Confirm before use.
    """
    query_vector = _embed_query(query)

    res = client.search(
        collection_name=collection,
        data=[query_vector.tolist()],
        limit=top_k,
        output_fields=["id", "content"],
    )

    return [item.get("entity").get("content") for items in res for item in items]


def late_chunking_query_by_cosine_sim(query, k = 3):
    """Return the ``k`` chunks whose late-chunking embeddings best match.

    Chunks are ranked by cosine similarity between the query embedding and
    each entry of the module-level ``chunk_embeddings``.
    """
    query_vector = _embed_query(query)

    results = np.empty(len(chunk_embeddings))
    for i, embedding in enumerate(chunk_embeddings):
        results[i] = cos_sim(query_vector, embedding)

    # Highest similarity first.
    results_order = results.argsort()[::-1]
    return np.array(chunks)[results_order].tolist()[:k]


print(late_chunking_query_by_cosine_sim("What are new features in milvus 2.4.13", 3))

# Compare late chunking against embedding each chunk on its own.
embeddings_traditional_chunking = model.encode(chunks[:10])

milvus_embedding = model.encode('milvus 2.4.13')

for chunk, late_chunking_embedding, traditional_embeddings in zip(chunks[4:8], chunk_embeddings[4:8], embeddings_traditional_chunking[4:8]):
    print(f'similarity_late_chunking("milvus 2.4.13", "{chunk}"):', cos_sim(milvus_embedding, late_chunking_embedding))
    print(f'similarity_traditional("milvus 2.4.13", "{chunk}"):', cos_sim(milvus_embedding, traditional_embeddings))
['\n\n### Critical bug fixes\n\n- Resolved a deadlock issue if the QueryNode crashed during shard client initialization ([#37354](', '\n\n## v2.4.15\n\nRelease Date: November 5, 2024\n\n| Milvus version | Python SDK version | Java SDK version | Node.js SDK version |\n|----------------|--------------------|------------------|---------------------|\n| 2.4.15 | 2.4.9 | 2.4.8 | 2.4.9 |\n\nMilvus 2.4.15 was a critical bug-fix release focused on enhancing system stability, performance, and compatibility.', '92](, [#31754](, [#32139](\n- Import tasks now support waiting for data index completion ([#31733](\n- Enhanced Import compatibility ([#32121](, task scheduling ([#31475](, and limits on imported file size and number ([#31542](\n- Code simplification efforts including interface standardization for type checking ([#31945](, [#31857](, removal of deprecated code and metrics ([#32079](, [#32134](, [#31535](, [#32211](, [#31935](, and normalization of constant names ([#31515](\n- New metrics for QueryCoord current target channel check point lag latency ([#31420](\n- New db label for common metrics ([#32024](\n- New metrics regarding the count of deleted, indexed, and loaded entities, with the inclusion of labels such as collectionName and dbName ([#31861](\n- Error handling improvements for mismatched vector types ([#31766](\n- Support for throwing errors instead of crashing when index cannot be built ([#31845](\n- Support for invalidating the database meta cache when dropping databases ([#32092](\n- Interface refactoring for channel distribution ([#31814]( and leader view management ([#32127](\n- Refactor channel dist manager interface ([#31814]( and Refactor leader view manager interface ([#32127](\n- Batch processing ([#31632](, adding mapping information ([#32234](, [#32249](, and avoiding usage of lock ([#31787]( to accelerate frequently invoked operations\n\n### Breaking Changes\n\n- Discontinued grouping search on binary vectors ([#31735](\n- Discontinued grouping 
search with hybrid search ([#31812](\n- Discontinued HNSW index on binary vectors ([#31883](\n\n### Bug Fixes\n\n- Enhanced data type and value checks for queries and insertions to prevent crashes ([#31478](, [#31653](, [#31698](, [#31842](, [#32042](, [#32251](, [#32204](\n- RESTful API bug fixes ([#32160](\n- Improved prediction of inverted index resource usage ([#31641](\n- Resolution of connection issues with etcd when authorization is enabled ([#31668](\n- Security update for nats server ([#32023](\n- Stored inverted index files into a local storage path of QueryNode instead of /tmp ([#32210](\n- Addressed datacoord memory leaks for collectionInfo ([#32243](\n- Fixes for fp16/bf16 related bugs potentially causing system panic ([#31677](, [#31841](, [#32196](\n- Resolved issues with grouping search returning insufficient results ([#32151](\n- Adjustment of search with iterators to handle offsets in the Reduce step more effectively and ensure adequate results with "reduceStopForBest" enabled ([#32088](\n\n## v2.4.0-rc.1\nRelease date: March 20, 2024\n\n| Milvus version | Python SDK version |\n|----------------|--------------------|\n| 2.4.0-rc.1 | 2.4.0 |\n\nThis release introduces several scenario-based features:\n\n- **New GPU Index - CAGRA**: Thanks to NVIDIA\'s contribution, this new GPU index offers a 10x performance boost, especially for batch searches.'] similarity_late_chunking("milvus 2.4.13", " ## v2.4.15 Release Date: November 5, 2024 | Milvus version | Python SDK version | Java SDK version | Node.js SDK version | |----------------|--------------------|------------------|---------------------| | 2.4.15 | 2.4.9 | 2.4.8 | 2.4.9 | Milvus 2.4.15 was a critical bug-fix release focused on enhancing system stability, performance, and compatibility."): 0.87544537 similarity_traditional("milvus 2.4.13", " ## v2.4.15 Release Date: November 5, 2024 | Milvus version | Python SDK version | Java SDK version | Node.js SDK version | 
|----------------|--------------------|------------------|---------------------| | 2.4.15 | 2.4.9 | 2.4.8 | 2.4.9 | Milvus 2.4.15 was a critical bug-fix release focused on enhancing system stability, performance, and compatibility."): 0.8466624 similarity_late_chunking("milvus 2.4.13", "This version addressed a major deadlock issue that could occur during QueryNode crashes and introduced compatibility updates for the backup tool with the database feature."): 0.840839 similarity_traditional("milvus 2.4.13", "This version addressed a major deadlock issue that could occur during QueryNode crashes and introduced compatibility updates for the backup tool with the database feature."): 0.6886654 similarity_late_chunking("milvus 2.4.13", "Additionally, Milvus 2.4.15 improved delete performance and stability through significant optimizations in L0 handling. **"): 0.8397808 similarity_traditional("milvus 2.4.13", "Additionally, Milvus 2.4.15 improved delete performance and stability through significant optimizations in L0 handling. **"): 0.84991246 similarity_late_chunking("milvus 2.4.13", "Upgrading to v2.4.15 was strongly recommended** to benefit from these critical enhancements."): 0.8456653 similarity_traditional("milvus 2.4.13", "Upgrading to v2.4.15 was strongly recommended** to benefit from these critical enhancements."): 0.7509077