[SentencePiece]Tokenizer的原理与实现
由来
无论在使用LLM大模型时,还是使用bert等传统的模型,对字符串进行编码都是必要的,只有经过编码后的字符串才能参与到后面的模型计算。
以下是在transformers库下的编码方式,无论是什么模型,AutoTokenizer隐藏了很多细节:
query = 'hello'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer.encode(query)
好处是在使用时不用管tokenizer的底层实现,只需要看看配置就可以了,但当需要自己去实现端到端的LLM推理时,就有点摸不着头脑了。
拆解transformers
因为transformers的库是python编写的,所以我们可以直接扒开里面的源码,看看他们的具体实现,这里以网易的BCE-Embedding为例,看看里面都做了些什么。
首先看到BCE-Embedding是在XLMRobertaModel下重新训练了语料,训练之后的长度是250005,包含了250000个正常token和5个特殊token。
编码方式
这5个特殊token可以在模型初始化时看到:
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
那么正常token是怎么保存的呢,可以看到其内部使用的是google的sentencepiece来保存的:
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
需要注意的是,XLMRobertaModel是fairseq下的模型,那么其特殊字符的加入位置是不一样的,另外XLMRobertaModel在末尾加了<mask>字符
Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-' |
spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' |
计算流程
一个query字符串近来的流程是怎样的呢,首先经过query会经过分词变成多个token piece,具体分词算法是bpe,然后模型字典中找token piece对应的id,当然由于特殊token是后来加的,所以优先寻找特殊token。
以下是源码中的具体实现,_tokenize方法将字符串分解为多个piece,_convert_token_to_id将对应的piece转换为对应的id,解码则是反过来的过程,逻辑是一样的:
def _tokenize(self, text: str) -> List[str]:
# TODO check if the t5/llama PR also applies here
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
# Need to return unknown token if the SP model returned 0
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
仅仅到这一步,我们便知道,XLMRobertaModel是包了一层特殊token的sentencepiece,在实现时只需要实现一个字典代替sentencepiece,剩余的把特殊token加入即可。
sentencepiece的实现
这里我们不免好奇,sentencepiece是怎么实现的呢?
BPE算法
通常情况下,Tokenizer有三种粒度:word/char/subword
- word: 按照词进行分词,如:
Today is sunday
. 则根据空格或标点进行分割[today, is, sunday, .]
- character:按照单字符进行分词,就是以char为最小粒度。 如:
Today is sunday.
则会分割成[t, o, d,a,y, .... ,s,u,n,d,a,y, .]
- subword:按照词的subword进行分词。如:
Today is sunday.
则会分割成[to, day,is , s,un,day, .]
算法流程:
- 确定词表大小,即subword的最大个数V;
- 在每个单词最后添加一个,并且统计每个单词出现的频率;
- 将所有单词拆分为单个字符,构建出初始的词表,此时词表的subword其实就是字符;
- 挑出频次最高的字符对,比如说
t
和h
组成的th
,将新字符加入词表,然后将语料中所有该字符对融合(merge),即所有t
和h
都变为th
。新字符依然可以参与后续的 merge以变成更长的piece。 - 重复3,4的操作,直到词表中单词数量达到预设的阈值V或者下一个字符对的频数为1;
sentencepiece中的源码实现:
其中SymbolPair代表了piece对,左右分别代表了合并的piece的来源,Symbol带代表了一个piece,这里的Symbol采用了类似了链式的方法,是为了避免piece合并后的内存移动,只需要用prev和next记录合并前后的邻居即可,寻找最大频率的合并piece使用priority_queue的方式。
这里的实现是上文算法的345部分,其中12在可以在前置处理过程中得到。
// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized,
float alpha) {
// util class begin
struct SymbolPair {
int left; // left index of this pair
int right; // right index of this pair
float score; // score of this pair. large is better.
size_t size; // length of this piece
};
class SymbolPairComparator {
public:
const bool operator()(SymbolPair *h1, SymbolPair *h2) {
return (h1->score < h2->score ||
(h1->score == h2->score && h1->left > h2->left));
}
};
struct Symbol {
int prev; // prev index of this symbol. -1 for BOS.
int next; // next index of tihs symbol. -1 for EOS.
bool freeze = false; // this symbol is never be merged.
string_view_ piece;
};
// util class end
using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
SymbolPairComparator>;
Agenda agenda;
std::vector<Symbol> symbols;
symbols.reserve(normalized.size());
// Reverse merge rules. key: merged symbol, value: pair of original symbols.
std::unordered_map<string_view_, std::pair<string_view_, string_view_>>
rev_merge;
// SymbolPair holder.
std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
// Lookup new symbol pair at [left, right] and inserts it to agenda.
}
- 将所有的normalized之后的string转换为单个字符:
while (!normalized.empty()) {
Symbol s;
// const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
int mblen =
std::min<int>(normalized.size(), one_char_len(normalized.data()));
s.piece = string_view_(normalized.data(), mblen);
s.prev = index == 0 ? -1 : index - 1;
normalized.remove_prefix(mblen);
s.next = normalized.empty() ? -1 : index + 1;
++index;
symbols.emplace_back(s);
}
这里判断单个字段的长度,取了个巧, (src & 0xFF) >> 4会进行8位截断,然后右移4位,将普通的Ascii过滤,特殊的字符会被编码位多个字节:
static inline size_t one_char_len(const char *src) {
return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
- 记录当前piece中可能存在的piece对,并将可能的piece对合并
对应的是关键函数是MaybeAddNewSymbolPair,该函数是匿名函数,会尝试搜索合并相邻的两个piece,这里尝试合并两个piece,如果能够合并,就加入到agenda队列中,symbol_pair_holder也会保存以便下一次合并:
auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda,
&rev_merge](int left, int right) {
if (left == -1 || right == -1 || symbols[left].freeze ||
symbols[right].freeze) {
return;
}
const string_view_ piece(symbols[left].piece.data(),
symbols[left].piece.size() +
symbols[right].piece.size());
std::string piece_str(piece.to_string());
const auto it = pieces_.find(piece_str);
if (it == pieces_.end()) {
return;
}
symbol_pair_holder.emplace_back(new SymbolPair);
auto *h = symbol_pair_holder.back().get();
h->left = left;
h->right = right;
h->score = get_score(it->second);
h->size = piece.size();
agenda.push(h);
// Makes `rev_merge` for resegmentation.
if (is_unused(it->second)) {
rev_merge[piece] =
std::make_pair(symbols[left].piece, symbols[right].piece);
}
};
在循环中,找到agenda中分数最高的,这里有一定概率的dropout,大概是10%,接着是合并agenda中最高pair到其左边,然后更新symbols,有点类似与删除链表的操作,只不过这里是改变前后邻居,然后循环合并前后邻居:
// Main loop.
while (!agenda.empty()) {
SymbolPair *top = agenda.top();
agenda.pop();
// `top` is no longer available.
if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
top->size) {
continue;
}
if (skip_merge())
continue;
// Replaces symbols with `top` rule.
symbols[top->left].piece = string_view_(
symbols[top->left].piece.data(),
symbols[top->left].piece.size() + symbols[top->right].piece.size());
// Updates prev/next pointers.
symbols[top->left].next = symbols[top->right].next;
if (symbols[top->right].next >= 0) {
symbols[symbols[top->right].next].prev = top->left;
}
symbols[top->right].piece = string_view_("");
// Adds new symbol pairs which are newly added after symbol replacement.
MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
}
相比与理想的算法,实际实现中多了一步,即将is_unused的id的pair再重新拆回去,具体的逆向字典则在merge时保存:
std::function<void(string_view_, EncodeResult *)> resegment;
resegment = [this, &resegment, &rev_merge](string_view_ w,
EncodeResult *output) -> void {
std::string w_str(w.to_string());
const int id = piece_to_id(w_str);
// std::cout << "piece: " << w << ", id = " << id << std::endl;
if (id == -1 || !is_unused(id)) {
output->emplace_back(w, id);
return;
}
const auto p = rev_merge.find(w);
if (p == rev_merge.end()) {
// This block will never be called, as `rev_merge` stores all the
// resegmentation info for unused id.
output->emplace_back(w, id);
return;
}
// Recursively resegment left and right symbols.
resegment(p->second.first, output);
resegment(p->second.second, output);
};
经过以上的算法转换之后,便可以将string转换为对应的ids了。
题外
除了BPE的编码方式,还有BBPE、WordPiece、Unigram等不同的方式,除了在具体处理方式上的不同,总体结构上是大同小异的。