条件随机场之CRF++源码详解-特征
我在学习条件随机场的时候经常有这样的疑问,crf预测当前节点label如何利用其他节点的信息、crf的训练样本与其他的分类器有什么不同、crf的公式中特征函数是什么以及这些特征函数是如何表示的。在这一章中,我将在CRF++源码中寻找答案。
输入过程
CRF++训练的入口在crf_learn.cpp文件的main函数中,在该函数中调用了encoder.cpp的crfpp_learn(int argc, char **argv)函数。在CRF++中,训练被称为encoder,显然预测就称为decoder。crfpp_learn的源码如下:
1 int crfpp_learn(int argc, char **argv) { 2 CRFPP::Param param; //存放输入的参数 3 param.open(argc, argv, CRFPP::long_options); //处理命令行输入的参数,存在param对象中 4 return CRFPP::crfpp_learn(param); 5 }
Param对象主要存放输入的参数,调用open方法处理命令行输入的参数并存储。最后调用crfpp_learn(const Param ¶m)函数,在该函数中将初始化Encoder对象encoder,并调用encoder的learn方法。
样本的处理以及特征的构造
本章的重点便是这个learn方法,该方法主要是根据输入的样本和特征模板构造特征。阅读该函数源码之前可以去CRF++官网了解一下CRF++输入的参数,以及模板文件和训练文件的格式。
1 bool Encoder::learn(const char *templfile, //模板文件 2 const char *trainfile, //训练样本 3 const char *modelfile, //模型输出文件 4 bool textmodelfile, 5 size_t maxitr, 6 size_t freq, 7 double eta, 8 double C, 9 unsigned short thread_num, 10 unsigned short shrinking_size, 11 int algorithm) { 12 std::cout << COPYRIGHT << std::endl; 13 14 CHECK_FALSE(eta > 0.0) << "eta must be > 0.0"; //CHECK_FALSE是宏定义,如果传入的条件是false,则输出异常信息 15 CHECK_FALSE(C >= 0.0) << "C must be >= 0.0"; 16 CHECK_FALSE(shrinking_size >= 1) << "shrinking-size must be >= 1"; 17 CHECK_FALSE(thread_num > 0) << "thread must be > 0"; 18 19 #ifndef CRFPP_USE_THREAD 20 CHECK_FALSE(thread_num == 1) 21 << "This architecture doesn't support multi-thrading"; 22 #endif 23 24 if (algorithm == MIRA && thread_num > 1) {//MIRAS算法无法启用多线程 25 std::cerr << "MIRA doesn't support multi-thrading. use thread_num=1" 26 << std::endl; 27 } 28 29 EncoderFeatureIndex feature_index; //所有的特征将存储在feature_index中 30 Allocator allocator(thread_num); //allocator对象主要用来做资源分配以及回收 31 std::vector<TaggerImpl* > x; //x存放输入的样本,例如:如果做词性标注的话,TaggerTmpl对象存放的是每句话,而x是所有句子 32 33 std::cout.setf(std::ios::fixed, std::ios::floatfield); 34 std::cout.precision(5); 35 36 #define WHAT_ERROR(msg) do { \ 37 for (std::vector<TaggerImpl *>::iterator it = x.begin(); \ 38 it != x.end(); ++it) \ 39 delete *it; \ 40 std::cerr << msg << std::endl; \ 41 return false; } while (0) 42 43 CHECK_FALSE(feature_index.open(templfile, trainfile)) //打开“模板文件”和“训练文件” 44 << feature_index.what(); 45 46 { 47 progress_timer pg; 48 49 std::ifstream ifs(WPATH(trainfile)); 50 CHECK_FALSE(ifs) << "cannot open: " << trainfile; 51 52 std::cout << "reading training data: " << std::flush; 53 size_t line = 0; 54 while (ifs) { //开始读取训练样本 55 TaggerImpl *_x = new TaggerImpl(); //_x存放的是一句话的内容,CRF++官网中提到,用一个空白行将每个sentence隔开 56 _x->open(&feature_index, &allocator); //做一些属性赋值,所有的句子都对应相同的feature_index和allocator对象 57 if (!_x->read(&ifs) || !_x->shrink()) { 58 WHAT_ERROR(_x->what()); 59 } 60 61 if (!_x->empty()) { 62 x.push_back(_x); 63 } else { 64 delete _x; 65 continue; 66 } 67 68 _x->set_thread_id(line % thread_num); //每个句子都会分配一个线程id,可以多线程并发处理不同的句子 69 70 if (++line % 100 == 0) { 71 std::cout << line << ".. " << std::flush; 72 } 73 } 74 75 ifs.close(); 76 std::cout << "\nDone!"; 77 } 78 79 feature_index.shrink(freq, &allocator); // 根据训练是指定的-f参数,将特征出现的频率小于freq的过滤掉 80 81 std::vector <double> alpha(feature_index.size()); // feature_index.size()返回的是maxid_,即:特征函数的个数,alpha是每个特征函数的权重,便是CRF中要学习的参数 82 std::fill(alpha.begin(), alpha.end(), 0.0); 83 feature_index.set_alpha(&alpha[0]); 84 85 std::cout << "Number of sentences: " << x.size() << std::endl; 86 std::cout << "Number of features: " << feature_index.size() << std::endl; 87 std::cout << "Number of thread(s): " << thread_num << std::endl; 88 std::cout << "Freq: " << freq << std::endl; 89 std::cout << "eta: " << eta << std::endl; 90 std::cout << "C: " << C << std::endl; 91 std::cout << "shrinking size: " << shrinking_size 92 << std::endl; 93 94 ... //省略后续代码
95 }
我阅读源码是按照深度优先遍历的方式,遇到一个函数会不断地深入进去,直到理解了该函数的功能再返回。上述源码需要重点介绍的部分,我也按照深度优先的方式记录。对于比较容易理解的部分则直接在源码中添加注释。首先看下源码第43行feature_index.open(templfile, trainfile),表面是理解是打开模板文件和训练集文件,但具体做了什么事儿呢,进入这个函数发现分别调用了两个函数。一个是EncoderFeatureIndex::openTemplate(const char *filename),这个函数主要是读取模板文件中的unigram特征和bigram特征分别存储,从官网文章中也可以知道,crf的特征分为两种特征,unigram对应的是状态特征,bigram对应的是转移特征。另一个函数是EncoderFeatureIndex::openTagSet(const char *filename),该函数读取训练集文件,获得训练集特征的数量(feature_index.xsize_属性)以及训练集中label的集合(feature_index.y_属性),以后可以用集合中label值得的索引代替label。
learn函数的第57行,有两个函数调用。一个是_x->read(&ifs),这个函数是对输入的样本做处理。解释该函数之前,我先做一个约定,以词性标注为例。我们输入的训练样本每一行代表一个词,每一列代表词的特征,多个词(多行)代表一个句子,句子与句子之间用空白行分隔。这个规则从CRF++文档中也能看出,我们就统一用句子和词表示,方便表达。那么,该函数会读取一个句子。经过层层调用,会对_x对象中几个重要的数据结构进行赋值,由于这个函数的处理逻辑不复杂,因此我直接给出最终赋值的结果。如下:
class TaggerImpl : public Tagger { FeatureIndex *feature_index_; Allocator *allocator_; std::vector<std::vector <const char *> > x_; //代表一个句子,外部vector代表多行(多个词),内部vector代表每行的多列,具体的列用char*表示 std::vector<std::vector <Node *> > node_; //相当于二位数组,node_[i][j]表示一个节点,即:第i个词是第j个label的点。如:“我”这个词是“代词” std::vector<unsigned short int> answer_; //每个词对应的label std::vector<unsigned short int> result_; };
另一个调用是_x->shrink(),该函数的主要功能就是构造特征,具体来说是调用了feature_index的FeatureIndex::buildFeatures(TaggerImpl *tagger)方法,源码如下:
#define ADD { const int id = getID(os.c_str()); \
if (id != -1) feature.push_back(id); } while (0)
bool FeatureIndex::buildFeatures(TaggerImpl *tagger) const { string_buffer os; std::vector<int> feature; FeatureCache *feature_cache = tagger->allocator()->feature_cache(); //存放是每个节点或者边对应的特征向量,节点便是node[i][j],边的概念后续会接触,暂时可以忽略 tagger->set_feature_id(feature_cache->size()); //做个标记,以后要取该句子的特征,可以从该id的位置取 for (size_t cur = 0; cur < tagger->size(); ++cur) {//遍历每个词,计算每个词的特征 for (std::vector<std::string>::const_iterator it = unigram_templs_.begin(); it != unigram_templs_.end(); ++it) { //遍历每个unigram特征 if (!applyRule(&os, it->c_str(), cur, *tagger)) {applyRule函数根据当前词(cur)以及当前的特征(如: %x[-2,0]),生成一个特征,存放在os中 return false; } ADD; //将根据特征os,获取该特征的id,如果不存在该特征,生成新的id,将该id添加到feature变量中 } feature_cache->add(feature); //将该词的特征添加到feature_cache中,add方法会将feature拷贝一份并将最后添加-1,方便后续读取 feature.clear(); } for (size_t cur = 1; cur < tagger->size(); ++cur) {//遍历每条边,计算每条边的特征 for (std::vector<std::string>::const_iterator it = bigram_templs_.begin(); it != bigram_templs_.end(); ++it) {//遍历每个bigram特征 if (!applyRule(&os, it->c_str(), cur, *tagger)) {//处理同上 return false; } ADD; } feature_cache->add(feature); feature.clear(); } return true; }
经过上面处理,最终会存储节点(单词)和边(相邻词连接)的特征列表(函数中feature变量),并存储在feature_cache中,由于在该函数中调用了set_feature_id方法,因此很容易拿到每个句子对应的特征列表。这里需要关注一下applyRule函数和ADD宏定义中的getID函数。下面我将举个例子,来直观感受下这两个函数的功能。
tempfile:
# Unigram
U00:%x[-1,0]
U01:%x[0,0]
trainfile:
0 - -1 -1 -1 -1 O
0 submit 7 0 0 0 B
1 submit 3 4 0 0 E
先看下CRF++中的特征模板,模板文件比较简单,只有unigram特征,特征的表示形如 U00:%x[a,b],开头的'U'代表unigram特征还是bigram特征,b代表的是哪列特征,a代表的是当前词的行偏移量。样本集文件更简单,只有一个句子,该句子有3个单词,每个单词有6个特征。
1) 当cur=0,遍历第一个unigram特征U00:%x[-1,0], 0代表第0个特征(第0列),-1代表前一个词的第0个特征。由于第一个词没有前一个词,所以CRF++中用_B-1代替,这部分可在源码中找到。调用applyRule将会生成"U00:_B-1"特征,调用getID函数返回的maxid_并存储在feature_index的dic_属性中,maxid_初始值为0,如果当前特征是新的则返回maxid_并更新maxid_为新值,maxid更新代码为maxid_ += (key[0] == 'U' ? y_.size() : y_.size() * y_.size()); 由于unigram是状态特征label与当前节点有关,所以加y_.size()表示y_.size()个特征函数,而bigram表示转移特征(边),与当前状态和前一个状态有关,有y_.size() * y_.size()种情况,因此加上y_.size() * y_.size(),代表y_.size()*y_.size()个特征函数。以上述例子unigram来说,对于某个词的特征,该词的label可能有y_.size()种情况,最终生成的特征函数是 f(特征='U00:_B-1', y='O')=1,f(特征='U00:_B-1', y='B')=1,f(特征='U00:_B-1', y='E')=1。总结一下,对于这个例子来说,一个unigram特征对应3状态特征函数,一个bigram特征对应9个转移特征函数。
2) 当cur=0,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3,feature变量为[0,3]
3) 当cur=1,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6
4) 当cur=1,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:0",调用getID函数,返回特征id为3, feature变量为[6,3]
5) 当cur=2,遍历第一个unigram特征U00:%x[-1,0],调用applyRule生成特征"U00:0",调用getID函数,返回特征id为6
6) 当cur=2,遍历第二个unigram特征U01:%x[0,0],调用applyRule生成特征"U01:1",调用getID函数,返回特征id为9, 此时maxid_更新为12,feature变量为[6,9]
因此,特征一共有4个,状态特征有12个,转移特征为0个,因此feature_index的maxid_为12,feature_cache的大小为5(3个节点+2条边)。本例子中只有1句话并且只有一个特征的unigram特征函数,对于多句话和多个特征函数,计算逻辑是一样的,并且都会更新到公共的变量feature_index中。
至此,就_x->shrink()的核心逻辑便梳理完毕, 同时也是整个learn函数的核心逻辑,回到learn函数的源码继续往下看,while循环会对每个句子重复进行上述操作,并将表示句子的变量x_存储到变量x中,代表整个训练集。还有需要注意的是我们平时一般用w表示待学习的参数,但在CRF++中使用变量alpha表示w。
总结
本章主要结合源码和实际的例子,了解了CRF++如何处理输入的样本,如何生成特征以及特征函数。首先,通过本章可以清晰的找到开头提到的几个问题。其次,可以学习CRF++如何定义数据结构表示条件随机场各个元素及其之间的关系,如果再仔细体会一下,就能发现CRF++里设计的数据结构和代码实现还是非常巧妙的,值得学习。如对本章内容有疑问的欢迎在留言区交流,我会及时回复,同时如有表述不对的地方,烦请指正。