[SentencePiece]Tokenizer的原理与实现
由来
无论在使用LLM大模型时,还是使用bert等传统的模型,对字符串进行编码都是必要的,只有经过编码后的字符串才能参与到后面的模型计算。
以下是在transformers库下的编码方式,无论是什么模型,AutoTokenizer隐藏了很多细节:
query = 'hello' tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) inputs = tokenizer.encode(query)
好处是在使用时不用管tokenizer的底层实现,只需要看看配置就可以了,但当需要自己去实现端到端的LLM推理时,就有点摸不着头脑了。
拆解transformers
因为transformers的库是python编写的,所以我们可以直接扒开里面的源码,看看他们的具体实现,这里以网易的BCE-Embedding为例,看看里面都做了些什么。
首先看到BCE-Embedding是在XLMRobertaModel下重新训练了语料,训练之后的长度是250005,包含了250000个正常token和5个特殊token。
编码方式
这5个特殊token可以在模型初始化时看到:
bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>", unk_token="<unk>", pad_token="<pad>", mask_token="<mask>",
那么正常token是怎么保存的呢,可以看到其内部使用的是google的sentencepiece来保存的:
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model.Load(str(vocab_file))
需要注意的是,XLMRobertaModel是fairseq下的模型,那么其特殊字符的加入位置是不一样的,另外XLMRobertaModel在末尾加了<mask>字符
Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-' |
spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' |
计算流程
一个query字符串近来的流程是怎样的呢,首先经过query会经过分词变成多个token piece,具体分词算法是bpe,然后模型字典中找token piece对应的id,当然由于特殊token是后来加的,所以优先寻找特殊token。
以下是源码中的具体实现,_tokenize方法将字符串分解为多个piece,_convert_token_to_id将对应的piece转换为对应的id,解码则是反过来的过程,逻辑是一样的:
def _tokenize(self, text: str) -> List[str]: # TODO check if the t5/llama PR also applies here return self.sp_model.encode(text, out_type=str) def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" if token in self.fairseq_tokens_to_ids: return self.fairseq_tokens_to_ids[token] spm_id = self.sp_model.PieceToId(token) # Need to return unknown token if the SP model returned 0 return spm_id + self.fairseq_offset if spm_id else self.unk_token_id def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" if index in self.fairseq_ids_to_tokens: return self.fairseq_ids_to_tokens[index] return self.sp_model.IdToPiece(index - self.fairseq_offset) def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (strings for sub-words) in a single string.""" out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() return out_string
仅仅到这一步,我们便知道,XLMRobertaModel是包了一层特殊token的sentencepiece,在实现时只需要实现一个字典代替sentencepiece,剩余的把特殊token加入即可。
sentencepiece的实现
这里我们不免好奇,sentencepiece是怎么实现的呢?
BPE算法
通常情况下,Tokenizer有三种粒度:word/char/subword
- word: 按照词进行分词,如:
Today is sunday
. 则根据空格或标点进行分割[today, is, sunday, .]
- character:按照单字符进行分词,就是以char为最小粒度。 如:
Today is sunday.
则会分割成[t, o, d,a,y, .... ,s,u,n,d,a,y, .]
- subword:按照词的subword进行分词。如:
Today is sunday.
则会分割成[to, day,is , s,un,day, .]
算法流程:
- 确定词表大小,即subword的最大个数V;
- 在每个单词最后添加一个,并且统计每个单词出现的频率;
- 将所有单词拆分为单个字符,构建出初始的词表,此时词表的subword其实就是字符;
- 挑出频次最高的字符对,比如说
t
和h
组成的th
,将新字符加入词表,然后将语料中所有该字符对融合(merge),即所有t
和h
都变为th
。新字符依然可以参与后续的 merge以变成更长的piece。 - 重复3,4的操作,直到词表中单词数量达到预设的阈值V或者下一个字符对的频数为1;
sentencepiece中的源码实现:
其中SymbolPair代表了piece对,左右分别代表了合并的piece的来源,Symbol带代表了一个piece,这里的Symbol采用了类似了链式的方法,是为了避免piece合并后的内存移动,只需要用prev和next记录合并前后的邻居即可,寻找最大频率的合并piece使用priority_queue的方式。
这里的实现是上文算法的345部分,其中12在可以在前置处理过程中得到。
// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, float alpha) { // util class begin struct SymbolPair { int left; // left index of this pair int right; // right index of this pair float score; // score of this pair. large is better. size_t size; // length of this piece }; class SymbolPairComparator { public: const bool operator()(SymbolPair *h1, SymbolPair *h2) { return (h1->score < h2->score || (h1->score == h2->score && h1->left > h2->left)); } }; struct Symbol { int prev; // prev index of this symbol. -1 for BOS. int next; // next index of tihs symbol. -1 for EOS. bool freeze = false; // this symbol is never be merged. string_view_ piece; }; // util class end using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>, SymbolPairComparator>; Agenda agenda; std::vector<Symbol> symbols; symbols.reserve(normalized.size()); // Reverse merge rules. key: merged symbol, value: pair of original symbols. std::unordered_map<string_view_, std::pair<string_view_, string_view_>> rev_merge; // SymbolPair holder. std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder; // Lookup new symbol pair at [left, right] and inserts it to agenda. }
- 将所有的normalized之后的string转换为单个字符:
while (!normalized.empty()) { Symbol s; // const int mblen = matcher_->PrefixMatch(normalized, &s.freeze); int mblen = std::min<int>(normalized.size(), one_char_len(normalized.data())); s.piece = string_view_(normalized.data(), mblen); s.prev = index == 0 ? -1 : index - 1; normalized.remove_prefix(mblen); s.next = normalized.empty() ? -1 : index + 1; ++index; symbols.emplace_back(s); }
这里判断单个字段的长度,取了个巧, (src & 0xFF) >> 4会进行8位截断,然后右移4位,将普通的Ascii过滤,特殊的字符会被编码位多个字节:
static inline size_t one_char_len(const char *src) { return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; }
- 记录当前piece中可能存在的piece对,并将可能的piece对合并
对应的是关键函数是MaybeAddNewSymbolPair,该函数是匿名函数,会尝试搜索合并相邻的两个piece,这里尝试合并两个piece,如果能够合并,就加入到agenda队列中,symbol_pair_holder也会保存以便下一次合并:
auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda, &rev_merge](int left, int right) { if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) { return; } const string_view_ piece(symbols[left].piece.data(), symbols[left].piece.size() + symbols[right].piece.size()); std::string piece_str(piece.to_string()); const auto it = pieces_.find(piece_str); if (it == pieces_.end()) { return; } symbol_pair_holder.emplace_back(new SymbolPair); auto *h = symbol_pair_holder.back().get(); h->left = left; h->right = right; h->score = get_score(it->second); h->size = piece.size(); agenda.push(h); // Makes `rev_merge` for resegmentation. if (is_unused(it->second)) { rev_merge[piece] = std::make_pair(symbols[left].piece, symbols[right].piece); } };
在循环中,找到agenda中分数最高的,这里有一定概率的dropout,大概是10%,接着是合并agenda中最高pair到其左边,然后更新symbols,有点类似与删除链表的操作,只不过这里是改变前后邻居,然后循环合并前后邻居:
// Main loop. while (!agenda.empty()) { SymbolPair *top = agenda.top(); agenda.pop(); // `top` is no longer available. if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() || symbols[top->left].piece.size() + symbols[top->right].piece.size() != top->size) { continue; } if (skip_merge()) continue; // Replaces symbols with `top` rule. symbols[top->left].piece = string_view_( symbols[top->left].piece.data(), symbols[top->left].piece.size() + symbols[top->right].piece.size()); // Updates prev/next pointers. symbols[top->left].next = symbols[top->right].next; if (symbols[top->right].next >= 0) { symbols[symbols[top->right].next].prev = top->left; } symbols[top->right].piece = string_view_(""); // Adds new symbol pairs which are newly added after symbol replacement. MaybeAddNewSymbolPair(symbols[top->left].prev, top->left); MaybeAddNewSymbolPair(top->left, symbols[top->left].next); }
相比与理想的算法,实际实现中多了一步,即将is_unused的id的pair再重新拆回去,具体的逆向字典则在merge时保存:
std::function<void(string_view_, EncodeResult *)> resegment; resegment = [this, &resegment, &rev_merge](string_view_ w, EncodeResult *output) -> void { std::string w_str(w.to_string()); const int id = piece_to_id(w_str); // std::cout << "piece: " << w << ", id = " << id << std::endl; if (id == -1 || !is_unused(id)) { output->emplace_back(w, id); return; } const auto p = rev_merge.find(w); if (p == rev_merge.end()) { // This block will never be called, as `rev_merge` stores all the // resegmentation info for unused id. output->emplace_back(w, id); return; } // Recursively resegment left and right symbols. resegment(p->second.first, output); resegment(p->second.second, output); };
经过以上的算法转换之后,便可以将string转换为对应的ids了。
题外
除了BPE的编码方式,还有BBPE、WordPiece、Unigram等不同的方式,除了在具体处理方式上的不同,总体结构上是大同小异的。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
2023-08-26 [fastllm]多线程下动态组batch实现解析