[SentencePiece]Tokenizer的原理与实现

由来

无论在使用LLM大模型时,还是使用bert等传统的模型,对字符串进行编码都是必要的,只有经过编码后的字符串才能参与到后面的模型计算。
以下是在transformers库下的编码方式,无论是什么模型,AutoTokenizer隐藏了很多细节:

query = 'hello'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer.encode(query)

好处是在使用时不用管tokenizer的底层实现,只需要看看配置就可以了,但当需要自己去实现端到端的LLM推理时,就有点摸不着头脑了。

拆解transformers

因为transformers的库是python编写的,所以我们可以直接扒开里面的源码,看看他们的具体实现,这里以网易的BCE-Embedding为例,看看里面都做了些什么。
首先看到BCE-Embedding是在XLMRobertaModel下重新训练了语料,训练之后的长度是250005,包含了250000个正常token和5个特殊token。

编码方式

这5个特殊token可以在模型初始化时看到:

bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",

那么正常token是怎么保存的呢,可以看到其内部使用的是google的sentencepiece来保存的:

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))

需要注意的是,XLMRobertaModel是fairseq下的模型,那么其特殊字符的加入位置是不一样的,另外XLMRobertaModel在末尾加了<mask>字符

Vocab 0 1 2 3 4 5 6 7 8 9
fairseq '<s>' '<pad>' '</s>' '<unk>' ',' '.' '▁' 's' '▁de' '-'
spm '<unk>' '<s>' '</s>' ',' '.' '▁' 's' '▁de' '-' '▁a'

计算流程

一个query字符串近来的流程是怎样的呢,首先经过query会经过分词变成多个token piece,具体分词算法是bpe,然后模型字典中找token piece对应的id,当然由于特殊token是后来加的,所以优先寻找特殊token。
以下是源码中的具体实现,_tokenize方法将字符串分解为多个piece,_convert_token_to_id将对应的piece转换为对应的id,解码则是反过来的过程,逻辑是一样的:

    def _tokenize(self, text: str) -> List[str]:
        # TODO check if the t5/llama PR also applies here
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)

        # Need to return unknown token if the SP model returned 0
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

仅仅到这一步,我们便知道,XLMRobertaModel是包了一层特殊token的sentencepiece,在实现时只需要实现一个字典代替sentencepiece,剩余的把特殊token加入即可。

sentencepiece的实现

这里我们不免好奇,sentencepiece是怎么实现的呢?

BPE算法

通常情况下,Tokenizer有三种粒度:word/char/subword

  • word: 按照词进行分词,如: Today is sunday. 则根据空格或标点进行分割[today, is, sunday, .]
  • character:按照单字符进行分词,就是以char为最小粒度。 如:Today is sunday. 则会分割成[t, o, d,a,y, .... ,s,u,n,d,a,y, .]
  • subword:按照词的subword进行分词。如:Today is sunday. 则会分割成[to, day,is , s,un,day, .]

算法流程:

  1. 确定词表大小,即subword的最大个数V;
  2. 在每个单词最后添加一个,并且统计每个单词出现的频率;
  3. 将所有单词拆分为单个字符,构建出初始的词表,此时词表的subword其实就是字符;
  4. 挑出频次最高的字符对,比如说th组成的th,将新字符加入词表,然后将语料中所有该字符对融合(merge),即所有th都变为th。新字符依然可以参与后续的 merge以变成更长的piece。
  5. 重复3,4的操作,直到词表中单词数量达到预设的阈值V或者下一个字符对的频数为1;

sentencepiece中的源码实现:

其中SymbolPair代表了piece对,左右分别代表了合并的piece的来源,Symbol带代表了一个piece,这里的Symbol采用了类似了链式的方法,是为了避免piece合并后的内存移动,只需要用prev和next记录合并前后的邻居即可,寻找最大频率的合并piece使用priority_queue的方式。

这里的实现是上文算法的345部分,其中12在可以在前置处理过程中得到。

// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized,
                                                      float alpha) {
  // util class begin
  struct SymbolPair {
    int left;    // left index of this pair
    int right;   // right index of this pair
    float score; // score of this pair. large is better.
    size_t size; // length of this piece
  };

  class SymbolPairComparator {
  public:
    const bool operator()(SymbolPair *h1, SymbolPair *h2) {
      return (h1->score < h2->score ||
              (h1->score == h2->score && h1->left > h2->left));
    }
  };

  struct Symbol {
    int prev;            // prev index of this symbol. -1 for BOS.
    int next;            // next index of tihs symbol. -1 for EOS.
    bool freeze = false; // this symbol is never be merged.
    string_view_ piece;
  };
  // util class end

  using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
                                     SymbolPairComparator>;
  Agenda agenda;
  std::vector<Symbol> symbols;
  symbols.reserve(normalized.size());
  // Reverse merge rules. key: merged symbol, value: pair of original symbols.
  std::unordered_map<string_view_, std::pair<string_view_, string_view_>>
      rev_merge;
  // SymbolPair holder.
  std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
  // Lookup new symbol pair at [left, right] and inserts it to agenda.
}
  1. 将所有的normalized之后的string转换为单个字符:
  while (!normalized.empty()) {
    Symbol s;
    // const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
    int mblen =
        std::min<int>(normalized.size(), one_char_len(normalized.data()));
    s.piece = string_view_(normalized.data(), mblen);
    s.prev = index == 0 ? -1 : index - 1;
    normalized.remove_prefix(mblen);
    s.next = normalized.empty() ? -1 : index + 1;
    ++index;
    symbols.emplace_back(s);
  }

这里判断单个字段的长度,取了个巧, (src & 0xFF) >> 4会进行8位截断,然后右移4位,将普通的Ascii过滤,特殊的字符会被编码位多个字节:

static inline size_t one_char_len(const char *src) {
  return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
  1. 记录当前piece中可能存在的piece对,并将可能的piece对合并

对应的是关键函数是MaybeAddNewSymbolPair,该函数是匿名函数,会尝试搜索合并相邻的两个piece,这里尝试合并两个piece,如果能够合并,就加入到agenda队列中,symbol_pair_holder也会保存以便下一次合并:

auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda,
                                &rev_merge](int left, int right) {
    if (left == -1 || right == -1 || symbols[left].freeze ||
        symbols[right].freeze) {
      return;
    }
    const string_view_ piece(symbols[left].piece.data(),
                             symbols[left].piece.size() +
                                 symbols[right].piece.size());
    std::string piece_str(piece.to_string());
    const auto it = pieces_.find(piece_str);
    if (it == pieces_.end()) {
      return;
    }
    symbol_pair_holder.emplace_back(new SymbolPair);
    auto *h = symbol_pair_holder.back().get();
    h->left = left;
    h->right = right;
    h->score = get_score(it->second);
    h->size = piece.size();
    agenda.push(h);

    // Makes `rev_merge` for resegmentation.
    if (is_unused(it->second)) {
      rev_merge[piece] =
          std::make_pair(symbols[left].piece, symbols[right].piece);
    }
  };

在循环中,找到agenda中分数最高的,这里有一定概率的dropout,大概是10%,接着是合并agenda中最高pair到其左边,然后更新symbols,有点类似与删除链表的操作,只不过这里是改变前后邻居,然后循环合并前后邻居:

// Main loop.
  while (!agenda.empty()) {
    SymbolPair *top = agenda.top();
    agenda.pop();

    // `top` is no longer available.
    if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
        symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
            top->size) {
      continue;
    }

    if (skip_merge())
      continue;
    // Replaces symbols with `top` rule.
    symbols[top->left].piece = string_view_(
        symbols[top->left].piece.data(),
        symbols[top->left].piece.size() + symbols[top->right].piece.size());

    // Updates prev/next pointers.
    symbols[top->left].next = symbols[top->right].next;
    if (symbols[top->right].next >= 0) {
      symbols[symbols[top->right].next].prev = top->left;
    }
    symbols[top->right].piece = string_view_("");

    // Adds new symbol pairs which are newly added after symbol replacement.
    MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
    MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
  }

相比与理想的算法,实际实现中多了一步,即将is_unused的id的pair再重新拆回去,具体的逆向字典则在merge时保存:

  std::function<void(string_view_, EncodeResult *)> resegment;
  resegment = [this, &resegment, &rev_merge](string_view_ w,
                                             EncodeResult *output) -> void {
    std::string w_str(w.to_string());
    const int id = piece_to_id(w_str);
    // std::cout << "piece: " << w << ", id = " << id << std::endl;
    if (id == -1 || !is_unused(id)) {
      output->emplace_back(w, id);
      return;
    }
    const auto p = rev_merge.find(w);
    if (p == rev_merge.end()) {
      // This block will never be called, as `rev_merge` stores all the
      // resegmentation info for unused id.
      output->emplace_back(w, id);
      return;
    }
    // Recursively resegment left and right symbols.
    resegment(p->second.first, output);
    resegment(p->second.second, output);
  };

经过以上的算法转换之后,便可以将string转换为对应的ids了。

题外

除了BPE的编码方式,还有BBPE、WordPiece、Unigram等不同的方式,除了在具体处理方式上的不同,总体结构上是大同小异的。

posted @ 2024-08-26 01:18  wildkid1024  阅读(263)  评论(0编辑  收藏  举报