RecursiveCharacterTextSplitter和CharacterTextSplitter代码随读

直接说结论:优先使用RecursiveCharacterTextSplitter,输入一个separator list。优先划分大的符号放到前面,比如句号,分号,划分小块的放到后面。

 

如果想让separator 不生效,就放一个原文中不会出现的一个符号,如果separator 为空的话,会有一个默认值self._separators = separators or ["\n\n", "\n", " ", ""]

 

 

 

先来看看CharacterTextSplitter,这个separator是一个字符串,说明是根据一个字符串进行分割,之后进行合并和chunk size

 

第一步就是_split_text_with_regex,其实没有什么说的,就是根据用户输入的符号, 切分成小块

第二步就是比较核心的_merge_splits,这个_merge_splits比较核心。大概

1,判断下一个d是否可以加入:

total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
如果可以加入,
直接
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
就比较简单。
关键如果下一个d加入后,大于_chunk_size了,那么就不能加入了,这个时候,由于有overlap的影响,所以还要处理下当前的current_doc。
看到核心的步骤是:
       if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
如果有结果,那么当前list就是要合并的,直接加入docs,接下来就是处理下一个块从哪里开始了
就是下面的代码。不断把前面的结果推出去,如果剩下的长度小于_chunk_overlap,并且 加入下个d后,小于chunksize。这个你想呀,如果加入d,直接大于chunksize了,下一个肯定不能加了,所以是控制让尽量多的把前面的结果pop出去
如果当前文档的大小超过了 self._chunk_overlap 或者包含下一个分割字符串会导致超过设定的 self._chunk_size,则从 current_doc 中移除前面的字符串以确保当前文本块不会过大。

 while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]

举个例子:

假设我们有以下字符串分段列表:

复制代码
splits = ["Hello", "world", "this", "is", "a", "test"]

我们要将它们拼接成长度不超过10的文本块。根据我们的代码逻辑,步骤如下:

  1. 初始化变量

    • docs = []
    • current_doc = []
    • total = 0
  2. 开始逐个遍历 splits 中的字符串。

  3. 当处理到 "Hello"(长度为5),我们添加它到 current_doc,更新 total 至5

    1. 接下来处理 "world"(长度为5)。此时,total + _len = 5 + 5 = 10。这个长度没有超过我们的 _chunk_size,所以我们将 "world" 也添加到 current_doc,这时 total 更新至10。

    2. 根据您的指定值,_chunk_size 等于 10,_chunk_overlap 等于 5,确实,在遇到 "this" 时,while 循环会被执行。让我们重新审视这个例子。

      当前文本块 current_doc 包含 ["Hello", "world"],并且 total 是 10。我们想要添加 "this"(长度为 4),但这将导致总长度达到 14 (10 + 4), 超过了 self._chunk_size 的限制。

      所以会执行以下代码:

      python复制代码
      if len(current_doc) > 0:
          doc = self._join_docs(current_doc, separator)
          if doc is not None:
              docs.append(doc)
      
      # 这里 total > self._chunk_overlap 成立 (10 > 5),因此会执行 while 循环
      while total > self._chunk_overlap or (total + _len > self._chunk_size and total > 0):
          total -= len(current_doc[0]) + 1  # 假设 separator_len 为 1
          current_doc = current_doc[1:]
      

      在 while 循环中:

      1. 我们将移除 current_doc 中的第一个元素 "Hello",并从 total 中减去 "Hello" 的长度加上分隔符的长度,即 total -= 5 + 1

      2. 然后,current_doc 变为 ["world"]total 更新为 4。

      现在,total 已经小于了 _chunk_overlap,因此循环会停止。这意味着 "world" 成为一个新的文本块的开始。

      然后代码继续向 current_doc 添加 "this" 并更新 total 至 8 (4 + 4)。这一过程会一直重复,直到所有的字符串分段都被遍历完毕。

      对此框架代码的最终结果如下:

      1. "Hello" 被作为单独的块添加到 docs 列表。
      2. "world" 和 "this" 被组合成下一个文本块,因为它们的组合长度没有超过 self._chunk_size。然后,根据剩下的字符串分段,会继续形成新的块。

      最终,文本块的列表 docs 会包含所有被处理后长度不超过 self._chunk_size 且在可能的情况下不小于 self._chunk_overlap 的字符串组合。

      这种方法主要用于确保数据块都是大小适中的,这样既可以提供给如语言模型这样的处理系统处理,又不至于丢失太多可能重要的上下文信息。

 
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

 

 

class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_rege


    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]


 

 

 

最后看一下RecursiveCharacterTextSplitter。这个separators输入的一个list。

首先用第一个分割符号尽心分割,如果分割后的chunk都小于chunsize,那就非常好办了,直接合并就好了,不需要其他的分割符号。

如果分割后chunk大于chunk size,那么就要使用下一个分割符进行分割,尽心递归运算。

 

 

代码的主要工作流程如下:

  1. 初始化最终的文本块列表 final_chunks 为一个空列表。

  2. 默认将 separators 中的最后一个元素作为分隔符 separator

  3. 初始化一个新的分隔符列表 new_separators

  4. 遍历输入的 separators 列表,根据每个分隔符 _sself._is_separator_regex 的值:

    • 如果 self._is_separator_regex 为真,说明 _s 应当被解释为正则表达式,因此不需要转义。
    • 如果 self._is_separator_regex 为假,说明 _s 应当被当作普通字符串分隔符,并使用 re.escape(_s) 转义任何特殊字符。
  5. 判断 _separator 是否匹配 text 中的部分文本:

    • 如果 _s 是一个空字符串,或者确实在 text 中找到匹配,将 _s 设置为 separator 并记录剩余的分隔符到 new_separators
  6. 使用确定的分隔符 separator 来通过正则表达式分割文本,得到 splits

  7. 遍历分割后的 splits:

    • 如果分割的字符串 s 的长度小于预设的 self._chunk_size 阈值,那么直接将其加入到 _good_splits 列表中。
    • 如果分割的字符串长度不小于阈值:
      • 判断 _good_splits 是否包含元素,如果有,则将 _good_splits 中的字符串合并并加入到 final_chunks 中;
      • 如果没有更多的新分隔符 new_separators,意味着不能进一步分割字符串,所以直接将它加入到 final_chunks
      • 如果还有新的分隔符,递归调用 _split_text 方法将字符串进一步分割,然后扩充到 final_chunks
  8. 检查 _good_splits 在最后一个字符串处理后是否仍有未处理的字符串。如果有,则合并它们并加入到 final_chunks

  9. 返回 final_chunks 作为包含所有分割文本块的最终列表。

这个方法的核心在于对文本进行递归分割,每个分割后的文本块都要符合一定长度的限制(self._chunk_size)。如果文本块长度过长,方法会根据 new_separators 进一步分割文本。通过使用多重分隔符和考虑各个分隔符是否作为正则表达式,这个方法能够灵活地适应各种文本分割需求。递归在这个方法中起到了重要的作用,它允许对过长的文本块进行层级式的细分处理。

 

class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)

 

posted @ 2024-03-22 15:12  dmesg  阅读(750)  评论(0编辑  收藏  举报