RecursiveCharacterTextSplitter和CharacterTextSplitter代码随读
直接说结论:优先使用RecursiveCharacterTextSplitter,输入一个separator list。优先划分大的符号放到前面,比如句号,分号,划分小块的放到后面。
如果想让separator 不生效,就放一个原文中不会出现的一个符号,如果separator 为空的话,会有一个默认值self._separators = separators or ["\n\n", "\n", " ", ""]
先来看看CharacterTextSplitter,这个separator是一个字符串,说明是根据一个字符串进行分割,之后进行合并和chunk size
第一步就是_split_text_with_regex,其实没有什么说的,就是根据用户输入的符号, 切分成小块
第二步就是比较核心的_merge_splits,这个_merge_splits比较核心。大概
1,判断下一个d是否可以加入:
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size
如果可以加入,
直接
current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0)
就比较简单。
关键如果下一个d加入后,大于_chunk_size了,那么就不能加入了,这个时候,由于有overlap的影响,所以还要处理下当前的current_doc。
看到核心的步骤是:
if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc)
如果有结果,那么当前list就是要合并的,直接加入docs,接下来就是处理下一个块从哪里开始了
就是下面的代码。不断把前面的结果推出去,如果剩下的长度小于_chunk_overlap,并且 加入下个d后,小于chunksize。这个你想呀,如果加入d,直接大于chunksize了,下一个肯定不能加了,所以是控制让尽量多的把前面的结果pop出去
如果当前文档的大小超过了self._chunk_overlap
或者包含下一个分割字符串会导致超过设定的self._chunk_size
,则从current_doc
中移除前面的字符串以确保当前文本块不会过大。
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
举个例子:
假设我们有以下字符串分段列表:
splits = ["Hello", "world", "this", "is", "a", "test"]
我们要将它们拼接成长度不超过10的文本块。根据我们的代码逻辑,步骤如下:
-
初始化变量
docs = []
current_doc = []
total = 0
-
开始逐个遍历
splits
中的字符串。 -
当处理到
"Hello"
(长度为5),我们添加它到current_doc
,更新total
至5 -
接下来处理
"world"
(长度为5)。此时,total + _len = 5 + 5 = 10
。这个长度没有超过我们的_chunk_size
,所以我们将"world"
也添加到current_doc
,这时total
更新至10。 -
根据您的指定值,
_chunk_size
等于 10,_chunk_overlap
等于 5,确实,在遇到"this"
时,while
循环会被执行。让我们重新审视这个例子。当前文本块
current_doc
包含["Hello", "world"]
,并且total
是 10。我们想要添加"this"
(长度为 4),但这将导致总长度达到 14 (10 + 4
), 超过了self._chunk_size
的限制。所以会执行以下代码:
python复制代码if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) # 这里 total > self._chunk_overlap 成立 (10 > 5),因此会执行 while 循环 while total > self._chunk_overlap or (total + _len > self._chunk_size and total > 0): total -= len(current_doc[0]) + 1 # 假设 separator_len 为 1 current_doc = current_doc[1:]
在 while 循环中:
-
我们将移除
current_doc
中的第一个元素"Hello"
,并从total
中减去"Hello"
的长度加上分隔符的长度,即total -= 5 + 1
。 -
然后,
current_doc
变为["world"]
,total
更新为 4。
现在,
total
已经小于了_chunk_overlap
,因此循环会停止。这意味着"world"
成为一个新的文本块的开始。然后代码继续向
current_doc
添加"this"
并更新total
至 8 (4 + 4
)。这一过程会一直重复,直到所有的字符串分段都被遍历完毕。对此框架代码的最终结果如下:
"Hello"
被作为单独的块添加到docs
列表。"world"
和"this"
被组合成下一个文本块,因为它们的组合长度没有超过self._chunk_size
。然后,根据剩下的字符串分段,会继续形成新的块。
最终,文本块的列表
docs
会包含所有被处理后长度不超过self._chunk_size
且在可能的情况下不小于self._chunk_overlap
的字符串组合。这种方法主要用于确保数据块都是大小适中的,这样既可以提供给如语言模型这样的处理系统处理,又不至于丢失太多可能重要的上下文信息。
-
-
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. separator_len = self._length_function(separator) docs = [] current_doc: List[str] = [] total = 0 for d in splits: _len = self._length_function(d) if ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size ): if total > self._chunk_size: logger.warning( f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) # Keep on popping if: # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): total -= self._length_function(current_doc[0]) + ( separator_len if len(current_doc) > 1 else 0 ) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) return docs
class CharacterTextSplitter(TextSplitter): """Splitting text that looks at characters.""" def __init__( self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator self._is_separator_regex = is_separator_rege def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. separator = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) splits = _split_text_with_regex(text, separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator return self._merge_splits(splits, _separator) def _split_text_with_regex( text: str, separator: str, keep_separator: bool ) -> List[str]: # Now that we have the separator, split the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. _splits = re.split(f"({separator})", text) splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] if len(_splits) % 2 == 0: splits += _splits[-1:] splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: splits = list(text) return [s for s in splits if s != ""]
最后看一下RecursiveCharacterTextSplitter。这个separators输入的一个list。
首先用第一个分割符号尽心分割,如果分割后的chunk都小于chunsize,那就非常好办了,直接合并就好了,不需要其他的分割符号。
如果分割后chunk大于chunk size,那么就要使用下一个分割符进行分割,尽心递归运算。
代码的主要工作流程如下:
-
初始化最终的文本块列表
final_chunks
为一个空列表。 -
默认将
separators
中的最后一个元素作为分隔符separator
。 -
初始化一个新的分隔符列表
new_separators
。 -
遍历输入的
separators
列表,根据每个分隔符_s
和self._is_separator_regex
的值:- 如果
self._is_separator_regex
为真,说明_s
应当被解释为正则表达式,因此不需要转义。 - 如果
self._is_separator_regex
为假,说明_s
应当被当作普通字符串分隔符,并使用re.escape(_s)
转义任何特殊字符。
- 如果
-
判断
_separator
是否匹配text
中的部分文本:- 如果
_s
是一个空字符串,或者确实在text
中找到匹配,将_s
设置为separator
并记录剩余的分隔符到new_separators
。
- 如果
-
使用确定的分隔符
separator
来通过正则表达式分割文本,得到splits
。 -
遍历分割后的
splits
:- 如果分割的字符串
s
的长度小于预设的self._chunk_size
阈值,那么直接将其加入到_good_splits
列表中。 - 如果分割的字符串长度不小于阈值:
- 判断
_good_splits
是否包含元素,如果有,则将_good_splits
中的字符串合并并加入到final_chunks
中; - 如果没有更多的新分隔符
new_separators
,意味着不能进一步分割字符串,所以直接将它加入到final_chunks
; - 如果还有新的分隔符,递归调用
_split_text
方法将字符串进一步分割,然后扩充到final_chunks
。
- 判断
- 如果分割的字符串
-
检查
_good_splits
在最后一个字符串处理后是否仍有未处理的字符串。如果有,则合并它们并加入到final_chunks
。 -
返回
final_chunks
作为包含所有分割文本块的最终列表。
这个方法的核心在于对文本进行递归分割,每个分割后的文本块都要符合一定长度的限制(self._chunk_size
)。如果文本块长度过长,方法会根据 new_separators
进一步分割文本。通过使用多重分隔符和考虑各个分隔符是否作为正则表达式,这个方法能够灵活地适应各种文本分割需求。递归在这个方法中起到了重要的作用,它允许对过长的文本块进行层级式的细分处理。
class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively look at characters. Recursively tries to split by different characters to find one that works. """ def __init__( self, separators: Optional[List[str]] = None, keep_separator: bool = True, is_separator_regex: bool = False, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] self._is_separator_regex = is_separator_regex def _split_text(self, text: str, separators: List[str]) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = separators[-1] new_separators = [] for i, _s in enumerate(separators): _separator = _s if self._is_separator_regex else re.escape(_s) if _s == "": separator = _s break if re.search(_separator, text): separator = _s new_separators = separators[i + 1 :] break _separator = separator if self._is_separator_regex else re.escape(separator) splits = _split_text_with_regex(text, _separator, self._keep_separator) # Now go merging things, recursively splitting longer texts. _good_splits = [] _separator = "" if self._keep_separator else separator for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) _good_splits = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) return final_chunks def split_text(self, text: str) -> List[str]: return self._split_text(text, self._separators)