fasttext源码剖析

目的:结合多方资料与个人理解,记录对 fastText 源码的剖析。

https://heleifz.github.io/14732610572844.html

http://www.cnblogs.com/peghoty/p/3857839.html

一:代码总体模块关联图:

核心模块是 fasttext.cc 和 model.cc,但辅助模块同样重要:它们是代码的"螺丝钉",决定了数据采用什么样的数据结构来组织。这部分值得学习借鉴,而且你会发现其中存储训练数据的结构是比较常用的手段,后期可以与其他源码中训练数据的存储结构做对比。

这一部分:"螺丝钉"(辅助)代码的剖析

二:dictionary模块

  1 /**
  2  * Copyright (c) 2016-present, Facebook, Inc.
  3  * All rights reserved.
  4  *
  5  * This source code is licensed under the BSD-style license found in the
  6  * LICENSE file in the root directory of this source tree. An additional grant
  7  * of patent rights can be found in the PATENTS file in the same directory.
  8  */
  9 
 10 #include "dictionary.h"
 11 
 12 #include <assert.h>
 13 
 14 #include <iostream>
 15 #include <algorithm>
 16 #include <iterator>
 17 #include <unordered_map>
 18 
 19 namespace fasttext {
 20 
 21 const std::string Dictionary::EOS = "</s>";
 22 const std::string Dictionary::BOW = "<";
 23 const std::string Dictionary::EOW = ">";
 24 
 25 Dictionary::Dictionary(std::shared_ptr<Args> args) {
 26   args_ = args;
 27   size_ = 0;
 28   nwords_ = 0;
 29   nlabels_ = 0;
 30   ntokens_ = 0;
 31   word2int_.resize(MAX_VOCAB_SIZE);//建立全词的索引,hash值在0~MAX_VOCAB_SIZE-1之间
 32   for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
 33     word2int_[i] = -1;
 34   }
 35 }
 36 //根据字符串,进行hash,hash后若是冲突则线性探索,找到其对应的hash位置
 37 int32_t Dictionary::find(const std::string& w) const {
 38   int32_t h = hash(w) % MAX_VOCAB_SIZE;
 39   while (word2int_[h] != -1 && words_[word2int_[h]].word != w) {
 40     h = (h + 1) % MAX_VOCAB_SIZE;
 41   }
 42   return h;
 43 }
 44 //向words_添加词,词可能是标签词
 45 void Dictionary::add(const std::string& w) {
 46   int32_t h = find(w);
 47   ntokens_++;//已处理的词
 48   if (word2int_[h] == -1) {
 49     entry e;
 50     e.word = w;
 51     e.count = 1;
 52     e.type = (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;//与给出标签相同,则表示标签词
 53     words_.push_back(e);
 54     word2int_[h] = size_++;
 55   } else {
 56     words_[word2int_[h]].count++;
 57   }
 58 }
 59 //返回纯词个数--去重
 60 int32_t Dictionary::nwords() const {
 61   return nwords_;
 62 }
 63 //标签词个数---去重
 64 int32_t Dictionary::nlabels() const {
 65   return nlabels_;
 66 }
 67 //返回已经处理的词数---可以重复
 68 int64_t Dictionary::ntokens() const {
 69   return ntokens_;
 70 }
 71 //获取纯词的ngram
 72 const std::vector<int32_t>& Dictionary::getNgrams(int32_t i) const {
 73   assert(i >= 0);
 74   assert(i < nwords_);
 75   return words_[i].subwords;
 76 }
 77 //获取纯词的ngram,根据词串
 78 const std::vector<int32_t> Dictionary::getNgrams(const std::string& word) const {
 79   int32_t i = getId(word);
 80   if (i >= 0) {
 81     return getNgrams(i);
 82   }
 83   //若是该词没有被入库词典中,未知词,则计算ngram
 84   //这就可以通过其他词的近似ngram来获取该词的ngram
 85   std::vector<int32_t> ngrams;
 86   computeNgrams(BOW + word + EOW, ngrams);
 87   return ngrams;
 88 }
 89 //是否丢弃的判断标准---这是由于无用词会出现过多的词频,需要被丢弃,
 90 bool Dictionary::discard(int32_t id, real rand) const {
 91   assert(id >= 0);
 92   assert(id < nwords_);
 93   if (args_->model == model_name::sup) return false;//非词向量不需要丢弃
 94   return rand > pdiscard_[id];
 95 }
 96 //获取词的id号
 97 int32_t Dictionary::getId(const std::string& w) const {
 98   int32_t h = find(w);
 99   return word2int_[h];
100 }
101 //词的类型
102 entry_type Dictionary::getType(int32_t id) const {
103   assert(id >= 0);
104   assert(id < size_);
105   return words_[id].type;
106 }
107 //根据词id获取词串
108 std::string Dictionary::getWord(int32_t id) const {
109   assert(id >= 0);
110   assert(id < size_);
111   return words_[id].word;
112 }
113 //hash规则
114 uint32_t Dictionary::hash(const std::string& str) const {
115   uint32_t h = 2166136261;
116   for (size_t i = 0; i < str.size(); i++) {
117     h = h ^ uint32_t(str[i]);
118     h = h * 16777619;
119   }
120   return h;
121 }
122 //根据词计算其ngram情况
123 void Dictionary::computeNgrams(const std::string& word,
124                                std::vector<int32_t>& ngrams) const {
125   for (size_t i = 0; i < word.size(); i++) {
126     std::string ngram;
127     if ((word[i] & 0xC0) == 0x80) continue;
128     for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {//n-1个词背景
129       ngram.push_back(word[j++]);
130       while (j < word.size() && (word[j] & 0xC0) == 0x80) {
131         ngram.push_back(word[j++]);
132       }
133       if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
134         int32_t h = hash(ngram) % args_->bucket;//hash余数值
135         ngrams.push_back(nwords_ + h);
136       }
137     }
138   }
139 }
140 //初始化ngram值
141 void Dictionary::initNgrams() {
142   for (size_t i = 0; i < size_; i++) {
143     std::string word = BOW + words_[i].word + EOW;
144     words_[i].subwords.push_back(i);
145     computeNgrams(word, words_[i].subwords);
146   }
147 }
148 //读取词
149 bool Dictionary::readWord(std::istream& in, std::string& word) const
150 {
151   char c;
152   std::streambuf& sb = *in.rdbuf();
153   word.clear();
154   while ((c = sb.sbumpc()) != EOF) {
155     if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' || c == '\f' || c == '\0') {
156       if (word.empty()) {
157         if (c == '\n') {//若是空行,则增加一个EOS
158           word += EOS;
159           return true;
160         }
161         continue;
162       } else {
163         if (c == '\n')
164           sb.sungetc();//放回,体现对于换行符会用EOS替换
165         return true;
166       }
167     }
168     word.push_back(c);
169   }
170   // trigger eofbit
171   in.get();
172   return !word.empty();
173 }
174 //读取文件---获取词典;初始化舍弃规则,初始化ngram
175 void Dictionary::readFromFile(std::istream& in) {
176   std::string word;
177   int64_t minThreshold = 1;//阈值
178   while (readWord(in, word)) {
179     add(word);
180     if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
181       std::cout << "\rRead " << ntokens_  / 1000000 << "M words" << std::flush;
182     }
183     if (size_ > 0.75 * MAX_VOCAB_SIZE) {//词保证是不超过75%
184       minThreshold++;
185       threshold(minThreshold, minThreshold);//过滤小于minThreshold的词,顺便排序了
186     }
187   }
188   threshold(args_->minCount, args_->minCountLabel);//目的是排序,顺带过滤词,指定过滤
189   
190   initTableDiscard();
191   initNgrams();
192   if (args_->verbose > 0) {
193     std::cout << "\rRead " << ntokens_  / 1000000 << "M words" << std::endl;
194     std::cout << "Number of words:  " << nwords_ << std::endl;
195     std::cout << "Number of labels: " << nlabels_ << std::endl;
196   }
197   if (size_ == 0) {
198     std::cerr << "Empty vocabulary. Try a smaller -minCount value." << std::endl;
199     exit(EXIT_FAILURE);
200   }
201 }
202 //缩减词,且排序词
203 void Dictionary::threshold(int64_t t, int64_t tl) {
204   sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
205       if (e1.type != e2.type) return e1.type < e2.type;//不同类型词,将标签词排在后面
206       return e1.count > e2.count;//同类则词频降序排
207     });//排序,根据词频
208   words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
209         return (e.type == entry_type::word && e.count < t) ||
210                (e.type == entry_type::label && e.count < tl);
211       }), words_.end());//删除阈值以下的词
212   words_.shrink_to_fit();//剔除
213   //更新词典的信息
214   size_ = 0;
215   nwords_ = 0;
216   nlabels_ = 0;
217   for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
218     word2int_[i] = -1;//重置
219   }
220   for (auto it = words_.begin(); it != words_.end(); ++it) {
221     int32_t h = find(it->word);//重新构造hash
222     word2int_[h] = size_++;
223     if (it->type == entry_type::word) nwords_++;
224     if (it->type == entry_type::label) nlabels_++;
225   }
226 }
227 //初始化丢弃规则---
228 void Dictionary::initTableDiscard() {//t采样的阈值,0表示全部舍弃,1表示不采样
229   pdiscard_.resize(size_);
230   for (size_t i = 0; i < size_; i++) {
231     real f = real(words_[i].count) / real(ntokens_);//f概率高
232     pdiscard_[i] = sqrt(args_->t / f) + args_->t / f;//与论文貌似不一样?????
233   }
234 }
235 //返回词的频数--所以词的词频和
236 std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
237   std::vector<int64_t> counts;
238   for (auto& w : words_) {
239     if (w.type == type) counts.push_back(w.count);
240   }
241   return counts;
242 }
243 //增加ngram,
244 void Dictionary::addNgrams(std::vector<int32_t>& line, int32_t n) const {
245   int32_t line_size = line.size();
246   for (int32_t i = 0; i < line_size; i++) {
247     uint64_t h = line[i];
248     for (int32_t j = i + 1; j < line_size && j < i + n; j++) {
249       h = h * 116049371 + line[j];
250       line.push_back(nwords_ + (h % args_->bucket));
251     }
252   }
253 }
254 //获取词行
255 int32_t Dictionary::getLine(std::istream& in,
256                             std::vector<int32_t>& words,
257                             std::vector<int32_t>& labels,
258                             std::minstd_rand& rng) const {
259   std::uniform_real_distribution<> uniform(0, 1);//均匀随机0~1
260   std::string token;
261   int32_t ntokens = 0;
262   words.clear();
263   labels.clear();
264   if (in.eof()) {
265     in.clear();
266     in.seekg(std::streampos(0));
267   }
268   while (readWord(in, token)) {
269     if (token == EOS) break;//表示一行的结束
270     int32_t wid = getId(token);
271     if (wid < 0) continue;//表示词的id木有,代表未知词,则跳过
272     entry_type type = getType(wid);
273     ntokens++;//已经获取词数
274     if (type == entry_type::word && !discard(wid, uniform(rng))) {//随机采取样,表示是否取该词
275       words.push_back(wid);//词的收集--词肯定在nwords_以下
276     }
277     if (type == entry_type::label) {//标签词全部采取,肯定在nwords_以上
278       labels.push_back(wid - nwords_);//也就是labels的值需要加上nwords才能够寻找到标签词
279     }
280     if (words.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break;//词向量则有限制句子长度
281   }
282   return ntokens;
283 }
284 //获取标签词,根据的是标签词的lid
285 std::string Dictionary::getLabel(int32_t lid) const {//标签词
286   assert(lid >= 0);
287   assert(lid < nlabels_);
288   return words_[lid + nwords_].word;
289 }
290 //保存词典
291 void Dictionary::save(std::ostream& out) const {
292   out.write((char*) &size_, sizeof(int32_t));
293   out.write((char*) &nwords_, sizeof(int32_t));
294   out.write((char*) &nlabels_, sizeof(int32_t));
295   out.write((char*) &ntokens_, sizeof(int64_t));
296   for (int32_t i = 0; i < size_; i++) {//
297     entry e = words_[i];
298     out.write(e.word.data(), e.word.size() * sizeof(char));//
299     out.put(0);//字符串结束标志位
300     out.write((char*) &(e.count), sizeof(int64_t));
301     out.write((char*) &(e.type), sizeof(entry_type));
302   }
303 }
304 //加载词典
305 void Dictionary::load(std::istream& in) {
306   words_.clear();
307   for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
308     word2int_[i] = -1;
309   }
310   in.read((char*) &size_, sizeof(int32_t));
311   in.read((char*) &nwords_, sizeof(int32_t));
312   in.read((char*) &nlabels_, sizeof(int32_t));
313   in.read((char*) &ntokens_, sizeof(int64_t));
314   for (int32_t i = 0; i < size_; i++) {
315     char c;
316     entry e;
317     while ((c = in.get()) != 0) {
318       e.word.push_back(c);
319     }
320     in.read((char*) &e.count, sizeof(int64_t));
321     in.read((char*) &e.type, sizeof(entry_type));
322     words_.push_back(e);
323     word2int_[find(e.word)] = i;//建立索引
324   }
325   initTableDiscard();//初始化抛弃规则
326   initNgrams();//初始化ngram词
327 }
328 
329 }
dictionary.cc

个人觉得有必要说明的地方:

1:关于字符串的映射过程以及索引是如何建立的(详见下图):涉及的函数主要是 find,它用 hash 函数计算哈希值,冲突时线性探测下一个槽位,再借助 2 个 vector 完成关联:StrToHash(find 函数,字符串 → 哈希槽位) → HashToIndex(word2int_ 数组,槽位 → 词下标) → IndexToStruct(words_ 数组,词下标 → 词条结构体)。

2:初始化几个有用的表,目的是加快运行速度:

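1)初始化 ngram 表:每个词都对应一个 id 列表,列表第一个元素是该词自身的 id,其后是它的各个字符 ngram 的 id(ngram 的 hash 值对 bucket 取余,再加上 nwords_ 偏移)。比如词 "我想你",通过 computeNgrams 函数可以计算出相应 ngram 的索引:假设 ngram 最短为 2、最长为 3,则由 "<我"、"我想"、"想你"、"你>"、"<我想"、"我想你"、"想你>" 这些子词组成。其中出现 "<" 和 ">",是因为代码会自动为词加上开始符和结束符。注意代码中的 "(word[j] & 0xC0) == 0x80" 是为了兼容 UTF-8 编码的汉字:形如 10xxxxxx 的字节是多字节字符的后续字节,把它们与前面的首字节合并,就能把一个完整的汉字作为一个"字"取出。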

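2)初始化 initTableDiscard 表:根据每个词的频率 f(词频 / 总词数)计算 pdiscard = √(t/f) + t/f,其中 t 为采样阈值(-t 参数)。训练时对词的每次出现抽取一个 [0,1) 的均匀随机数 r,若 r > pdiscard 则丢弃该词,所以 pdiscard 实际上是"保留概率"。频率过高的词往往是信息量小的无用词(比如"的"、"是"等),f 越大保留概率越小,因此更容易被丢弃。这里的实现与论文有点差异:论文中的丢弃概率是 1 − √(t/f),即保留概率为 p = √(t/f);而代码中的保留概率是 p + p²。由于 p + p² ≥ p,代码的打压其实更宽松,也就是更少的词会被当作无用词丢弃。这一公式与 word2vec 原始 C 实现中的下采样公式一致。另外,监督分类(supervised)模式下不会丢弃任何词。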

3:外部使用该文件的主线有两条:一是 readFromFile 函数,读取语料、建立词典;二是 getLine 函数,逐行获取句子中的词(及标签)。

matrix.cc、vector.cc、args.cc 等类似代码的解析如下:

  1 /**
  2  * Copyright (c) 2016-present, Facebook, Inc.
  3  * All rights reserved.
  4  *
  5  * This source code is licensed under the BSD-style license found in the
  6  * LICENSE file in the root directory of this source tree. An additional grant
  7  * of patent rights can be found in the PATENTS file in the same directory.
  8  */
  9 
 10 #include "matrix.h"
 11 
 12 #include <assert.h>
 13 
 14 #include <random>
 15 
 16 #include "utils.h"
 17 #include "vector.h"
 18 
 19 namespace fasttext {
 20 
 21 Matrix::Matrix() {
 22   m_ = 0;
 23   n_ = 0;
 24   data_ = nullptr;
 25 }
 26 
 27 Matrix::Matrix(int64_t m, int64_t n) {
 28   m_ = m;
 29   n_ = n;
 30   data_ = new real[m * n];
 31 }
 32 
 33 Matrix::Matrix(const Matrix& other) {
 34   m_ = other.m_;
 35   n_ = other.n_;
 36   data_ = new real[m_ * n_];
 37   for (int64_t i = 0; i < (m_ * n_); i++) {
 38     data_[i] = other.data_[i];
 39   }
 40 }
 41 
 42 Matrix& Matrix::operator=(const Matrix& other) {
 43   Matrix temp(other);
 44   m_ = temp.m_;
 45   n_ = temp.n_;
 46   std::swap(data_, temp.data_);
 47   return *this;
 48 }
 49 
 50 Matrix::~Matrix() {
 51   delete[] data_;
 52 }
 53 
 54 void Matrix::zero() {
 55   for (int64_t i = 0; i < (m_ * n_); i++) {
 56       data_[i] = 0.0;
 57   }
 58 }
 59 //随机初始化矩阵-均匀随机
 60 void Matrix::uniform(real a) {
 61   std::minstd_rand rng(1);
 62   std::uniform_real_distribution<> uniform(-a, a);
 63   for (int64_t i = 0; i < (m_ * n_); i++) {
 64     data_[i] = uniform(rng);
 65   }
 66 }
 67 //加向量
 68 void Matrix::addRow(const Vector& vec, int64_t i, real a) {
 69   assert(i >= 0);
 70   assert(i < m_);
 71   assert(vec.m_ == n_);
 72   for (int64_t j = 0; j < n_; j++) {
 73     data_[i * n_ + j] += a * vec.data_[j];
 74   }
 75 }
 76 //点乘向量
 77 real Matrix::dotRow(const Vector& vec, int64_t i) {
 78   assert(i >= 0);
 79   assert(i < m_);
 80   assert(vec.m_ == n_);
 81   real d = 0.0;
 82   for (int64_t j = 0; j < n_; j++) {
 83     d += data_[i * n_ + j] * vec.data_[j];
 84   }
 85   return d;
 86 }
 87 //存储
 88 void Matrix::save(std::ostream& out) {
 89   out.write((char*) &m_, sizeof(int64_t));
 90   out.write((char*) &n_, sizeof(int64_t));
 91   out.write((char*) data_, m_ * n_ * sizeof(real));
 92 }
 93 //加载
 94 void Matrix::load(std::istream& in) {
 95   in.read((char*) &m_, sizeof(int64_t));
 96   in.read((char*) &n_, sizeof(int64_t));
 97   delete[] data_;
 98   data_ = new real[m_ * n_];
 99   in.read((char*) data_, m_ * n_ * sizeof(real));
100 }
101 
102 }
matrix.cc
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "vector.h"

#include <assert.h>

#include <iomanip>

#include "matrix.h"
#include "utils.h"

namespace fasttext {

Vector::Vector(int64_t m) {
  m_ = m;
  data_ = new real[m];
}

Vector::~Vector() {
  delete[] data_;
}

int64_t Vector::size() const {
  return m_;
}

void Vector::zero() {
  for (int64_t i = 0; i < m_; i++) {
    data_[i] = 0.0;
  }
}
//数乘向量
void Vector::mul(real a) {
  for (int64_t i = 0; i < m_; i++) {
    data_[i] *= a;
  }
}
//向量相加
void Vector::addRow(const Matrix& A, int64_t i) {
  assert(i >= 0);
  assert(i < A.m_);
  assert(m_ == A.n_);
  for (int64_t j = 0; j < A.n_; j++) {
    data_[j] += A.data_[i * A.n_ + j];
  }
}
//加数乘向量
void Vector::addRow(const Matrix& A, int64_t i, real a) {
  assert(i >= 0);
  assert(i < A.m_);
  assert(m_ == A.n_);
  for (int64_t j = 0; j < A.n_; j++) {
    data_[j] += a * A.data_[i * A.n_ + j];
  }
}
//向量与矩阵相乘得到的向量
void Vector::mul(const Matrix& A, const Vector& vec) {
  assert(A.m_ == m_);
  assert(A.n_ == vec.m_);
  for (int64_t i = 0; i < m_; i++) {
    data_[i] = 0.0;
    for (int64_t j = 0; j < A.n_; j++) {
      data_[i] += A.data_[i * A.n_ + j] * vec.data_[j];
    }
  }
}
//最大分量
int64_t Vector::argmax() {
  real max = data_[0];
  int64_t argmax = 0;
  for (int64_t i = 1; i < m_; i++) {
    if (data_[i] > max) {
      max = data_[i];
      argmax = i;
    }
  }
  return argmax;
}

real& Vector::operator[](int64_t i) {
  return data_[i];
}

const real& Vector::operator[](int64_t i) const {
  return data_[i];
}

std::ostream& operator<<(std::ostream& os, const Vector& v)
{
  os << std::setprecision(5);
  for (int64_t j = 0; j < v.m_; j++) {
    os << v.data_[j] << ' ';
  }
  return os;
}

}
vector.cc
  1 /**
  2  * Copyright (c) 2016-present, Facebook, Inc.
  3  * All rights reserved.
  4  *
  5  * This source code is licensed under the BSD-style license found in the
  6  * LICENSE file in the root directory of this source tree. An additional grant
  7  * of patent rights can be found in the PATENTS file in the same directory.
  8  */
  9 
 10 #include "args.h"
 11 
 12 #include <stdlib.h>
 13 #include <string.h>
 14 
 15 #include <iostream>
 16 
 17 namespace fasttext {
 18 
 19 Args::Args() {
 20   lr = 0.05;
 21   dim = 100;
 22   ws = 5;
 23   epoch = 5;
 24   minCount = 5;
 25   minCountLabel = 0;
 26   neg = 5;
 27   wordNgrams = 1;
 28   loss = loss_name::ns;
 29   model = model_name::sg;
 30   bucket = 2000000;//允许的ngram词典大小2M
 31   minn = 3;
 32   maxn = 6;
 33   thread = 12;
 34   lrUpdateRate = 100;
 35   t = 1e-4;//默认
 36   label = "__label__";
 37   verbose = 2;
 38   pretrainedVectors = "";
 39 }
 40 
 41 void Args::parseArgs(int argc, char** argv) {
 42   std::string command(argv[1]);
 43   if (command == "supervised") {
 44     model = model_name::sup;
 45     loss = loss_name::softmax;
 46     minCount = 1;
 47     minn = 0;
 48     maxn = 0;
 49     lr = 0.1;
 50   } else if (command == "cbow") {
 51     model = model_name::cbow;
 52   }
 53   int ai = 2;
 54   while (ai < argc) {
 55     if (argv[ai][0] != '-') {
 56       std::cout << "Provided argument without a dash! Usage:" << std::endl;
 57       printHelp();
 58       exit(EXIT_FAILURE);
 59     }
 60     if (strcmp(argv[ai], "-h") == 0) {
 61       std::cout << "Here is the help! Usage:" << std::endl;
 62       printHelp();
 63       exit(EXIT_FAILURE);
 64     } else if (strcmp(argv[ai], "-input") == 0) {
 65       input = std::string(argv[ai + 1]);
 66     } else if (strcmp(argv[ai], "-test") == 0) {
 67       test = std::string(argv[ai + 1]);
 68     } else if (strcmp(argv[ai], "-output") == 0) {
 69       output = std::string(argv[ai + 1]);
 70     } else if (strcmp(argv[ai], "-lr") == 0) {
 71       lr = atof(argv[ai + 1]);
 72     } else if (strcmp(argv[ai], "-lrUpdateRate") == 0) {
 73       lrUpdateRate = atoi(argv[ai + 1]);
 74     } else if (strcmp(argv[ai], "-dim") == 0) {
 75       dim = atoi(argv[ai + 1]);
 76     } else if (strcmp(argv[ai], "-ws") == 0) {
 77       ws = atoi(argv[ai + 1]);
 78     } else if (strcmp(argv[ai], "-epoch") == 0) {
 79       epoch = atoi(argv[ai + 1]);
 80     } else if (strcmp(argv[ai], "-minCount") == 0) {
 81       minCount = atoi(argv[ai + 1]);
 82     } else if (strcmp(argv[ai], "-minCountLabel") == 0) {
 83       minCountLabel = atoi(argv[ai + 1]);
 84     } else if (strcmp(argv[ai], "-neg") == 0) {
 85       neg = atoi(argv[ai + 1]);
 86     } else if (strcmp(argv[ai], "-wordNgrams") == 0) {
 87       wordNgrams = atoi(argv[ai + 1]);
 88     } else if (strcmp(argv[ai], "-loss") == 0) {
 89       if (strcmp(argv[ai + 1], "hs") == 0) {
 90         loss = loss_name::hs;
 91       } else if (strcmp(argv[ai + 1], "ns") == 0) {
 92         loss = loss_name::ns;
 93       } else if (strcmp(argv[ai + 1], "softmax") == 0) {
 94         loss = loss_name::softmax;
 95       } else {
 96         std::cout << "Unknown loss: " << argv[ai + 1] << std::endl;
 97         printHelp();
 98         exit(EXIT_FAILURE);
 99       }
100     } else if (strcmp(argv[ai], "-bucket") == 0) {
101       bucket = atoi(argv[ai + 1]);
102     } else if (strcmp(argv[ai], "-minn") == 0) {
103       minn = atoi(argv[ai + 1]);
104     } else if (strcmp(argv[ai], "-maxn") == 0) {
105       maxn = atoi(argv[ai + 1]);
106     } else if (strcmp(argv[ai], "-thread") == 0) {
107       thread = atoi(argv[ai + 1]);
108     } else if (strcmp(argv[ai], "-t") == 0) {
109       t = atof(argv[ai + 1]);
110     } else if (strcmp(argv[ai], "-label") == 0) {
111       label = std::string(argv[ai + 1]);
112     } else if (strcmp(argv[ai], "-verbose") == 0) {
113       verbose = atoi(argv[ai + 1]);
114     } else if (strcmp(argv[ai], "-pretrainedVectors") == 0) {
115       pretrainedVectors = std::string(argv[ai + 1]);
116     } else {
117       std::cout << "Unknown argument: " << argv[ai] << std::endl;
118       printHelp();
119       exit(EXIT_FAILURE);
120     }
121     ai += 2;
122   }
123   if (input.empty() || output.empty()) {
124     std::cout << "Empty input or output path." << std::endl;
125     printHelp();
126     exit(EXIT_FAILURE);
127   }
128   if (wordNgrams <= 1 && maxn == 0) {
129     bucket = 0;
130   }
131 }
132 
133 void Args::printHelp() {
134   std::string lname = "ns";
135   if (loss == loss_name::hs) lname = "hs";
136   if (loss == loss_name::softmax) lname = "softmax";
137   std::cout
138     << "\n"
139     << "The following arguments are mandatory:\n"
140     << "  -input              training file path\n"
141     << "  -output             output file path\n\n"
142     << "The following arguments are optional:\n"
143     << "  -lr                 learning rate [" << lr << "]\n"
144     << "  -lrUpdateRate       change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
145     << "  -dim                size of word vectors [" << dim << "]\n"
146     << "  -ws                 size of the context window [" << ws << "]\n"
147     << "  -epoch              number of epochs [" << epoch << "]\n"
148     << "  -minCount           minimal number of word occurences [" << minCount << "]\n"
149     << "  -minCountLabel      minimal number of label occurences [" << minCountLabel << "]\n"
150     << "  -neg                number of negatives sampled [" << neg << "]\n"
151     << "  -wordNgrams         max length of word ngram [" << wordNgrams << "]\n"
152     << "  -loss               loss function {ns, hs, softmax} [ns]\n"
153     << "  -bucket             number of buckets [" << bucket << "]\n"
154     << "  -minn               min length of char ngram [" << minn << "]\n"
155     << "  -maxn               max length of char ngram [" << maxn << "]\n"
156     << "  -thread             number of threads [" << thread << "]\n"
157     << "  -t                  sampling threshold [" << t << "]\n"
158     << "  -label              labels prefix [" << label << "]\n"
159     << "  -verbose            verbosity level [" << verbose << "]\n"
160     << "  -pretrainedVectors  pretrained word vectors for supervised learning []"
161     << std::endl;
162 }
163 
164 void Args::save(std::ostream& out) {
165   out.write((char*) &(dim), sizeof(int));
166   out.write((char*) &(ws), sizeof(int));
167   out.write((char*) &(epoch), sizeof(int));
168   out.write((char*) &(minCount), sizeof(int));
169   out.write((char*) &(neg), sizeof(int));
170   out.write((char*) &(wordNgrams), sizeof(int));
171   out.write((char*) &(loss), sizeof(loss_name));
172   out.write((char*) &(model), sizeof(model_name));
173   out.write((char*) &(bucket), sizeof(int));
174   out.write((char*) &(minn), sizeof(int));
175   out.write((char*) &(maxn), sizeof(int));
176   out.write((char*) &(lrUpdateRate), sizeof(int));
177   out.write((char*) &(t), sizeof(double));
178 }
179 
180 void Args::load(std::istream& in) {
181   in.read((char*) &(dim), sizeof(int));
182   in.read((char*) &(ws), sizeof(int));
183   in.read((char*) &(epoch), sizeof(int));
184   in.read((char*) &(minCount), sizeof(int));
185   in.read((char*) &(neg), sizeof(int));
186   in.read((char*) &(wordNgrams), sizeof(int));
187   in.read((char*) &(loss), sizeof(loss_name));
188   in.read((char*) &(model), sizeof(model_name));
189   in.read((char*) &(bucket), sizeof(int));
190   in.read((char*) &(minn), sizeof(int));
191   in.read((char*) &(maxn), sizeof(int));
192   in.read((char*) &(lrUpdateRate), sizeof(int));
193   in.read((char*) &(t), sizeof(double));
194 }
195 
196 }
args.cc

三:model.cc

/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "model.h"

#include <assert.h>

#include <algorithm>

#include "utils.h"

namespace fasttext {

// Binds the model to the shared input matrix `wi` (one row per input id) and
// output matrix `wo` (one row per output id / internal tree node), and to the
// training arguments. hidden_ and grad_ get args->dim elements and output_
// one per row of wo. `seed` seeds this model's RNG. The sigmoid and log
// lookup tables are allocated here and freed by the destructor.
Model::Model(std::shared_ptr<Matrix> wi,
             std::shared_ptr<Matrix> wo,
             std::shared_ptr<Args> args,
             int32_t seed)
  : hidden_(args->dim), output_(wo->m_), grad_(args->dim), rng(seed)
{
  args_ = args;
  wi_ = wi;
  wo_ = wo;
  hsz_ = args_->dim;
  isz_ = wi_->m_;
  osz_ = wo_->m_;
  loss_ = 0.0;
  nexamples_ = 1;  // starts at 1 so getLoss() never divides by zero
  negpos = 0;
  initSigmoid();
  initLog();
}

// Frees the lookup tables allocated by initSigmoid() / initLog().
Model::~Model() {
  delete[] t_sigmoid;
  delete[] t_log;
}
// One binary logistic-regression step on output row `target`.
// score = sigmoid(wo[target] . hidden_) is the predicted probability that
// `label` is true. The error term alpha = lr * (label - score) is accumulated
// into grad_ (using wo[target] before its update), then wo[target] is moved
// by alpha * hidden_. Returns the negative log-likelihood of `label`.
real Model::binaryLogistic(int32_t target, bool label, real lr) {
  const real score = sigmoid(wo_->dotRow(hidden_, target));
  const real alpha = lr * (real(label) - score);
  grad_.addRow(*wo_, target, alpha);
  wo_->addRow(hidden_, target, alpha);
  return label ? -log(score) : -log(1.0 - score);
}

// Negative sampling: one positive update for `target` followed by args_->neg
// negative updates on ids drawn by getNegative(). Resets grad_ first and
// returns the summed loss.
real Model::negativeSampling(int32_t target, real lr) {
  real loss = 0.0;
  grad_.zero();
  for (int32_t n = 0; n <= args_->neg; n++) {
    const bool positive = (n == 0);
    const int32_t id = positive ? target : getNegative(target);
    loss += binaryLogistic(id, positive, lr);
  }
  return loss;
}

// Hierarchical softmax: one binary decision per internal node on the path
// from leaf `target` to the root (paths/codes are built by buildTree).
// Resets grad_ first and returns the summed loss.
real Model::hierarchicalSoftmax(int32_t target, real lr) {
  real loss = 0.0;
  grad_.zero();
  const std::vector<bool>& code = codes[target];
  const std::vector<int32_t>& path = paths[target];
  for (size_t d = 0; d < path.size(); d++) {
    loss += binaryLogistic(path[d], code[d], lr);
  }
  return loss;
}
// output = softmax(wo * hidden) over all osz_ output rows.
// The maximum logit is subtracted before exponentiating for numerical
// stability.
void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
  output.mul(*wo_, hidden);
  real peak = output[0];
  real norm = 0.0;
  for (int32_t i = 0; i < osz_; i++) {
    peak = std::max(output[i], peak);
  }
  for (int32_t i = 0; i < osz_; i++) {
    output[i] = exp(output[i] - peak);
    norm += output[i];
  }
  for (int32_t i = 0; i < osz_; i++) {
    output[i] /= norm;
  }
}

// Softmax of the model's own hidden_ into output_.
void Model::computeOutputSoftmax() {
  computeOutputSoftmax(hidden_, output_);
}

// Full softmax update: every output row i receives
// alpha = lr * (1[i == target] - p_i). The gradient for the inputs is
// accumulated in grad_ before row i is updated.
// Returns -log(p_target).
real Model::softmax(int32_t target, real lr) {
  grad_.zero();
  computeOutputSoftmax();
  for (int32_t i = 0; i < osz_; i++) {
    const real alpha = lr * (real(i == target) - output_[i]);
    grad_.addRow(*wo_, i, alpha);
    wo_->addRow(hidden_, i, alpha);
  }
  return -log(output_[target]);
}
// hidden = average of the input-matrix rows selected by `input`.
// NOTE(review): with an empty `input` this divides by zero (NaN entries);
// update() guards against that, predict() does not.
void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) const {
  assert(hidden.size() == hsz_);
  hidden.zero();
  for (const int32_t row : input) {
    hidden.addRow(*wi_, row);
  }
  hidden.mul(1.0 / input.size());
}

// Heap comparator on the score (first): with the std heap algorithms it
// yields a min-heap, and sort_heap with it sorts by decreasing score.
bool Model::comparePairs(const std::pair<real, int32_t> &l,
                         const std::pair<real, int32_t> &r) {
  return l.first > r.first;
}

// Top-k prediction for `input`. Fills `heap` with up to k
// (log-probability, id) pairs sorted by decreasing score, using the
// caller-provided hidden/output buffers.
// NOTE(review): heap is not cleared here — presumably callers pass it empty.
void Model::predict(const std::vector<int32_t>& input, int32_t k,
                    std::vector<std::pair<real, int32_t>>& heap,
                    Vector& hidden, Vector& output) const {
  assert(k > 0);
  heap.reserve(k + 1);
  computeHidden(input, hidden);
  if (args_->loss != loss_name::hs) {
    findKBest(k, heap, hidden, output);
  } else {
    // Start the tree search at the root, the last node built by buildTree.
    dfs(k, 2 * osz_ - 2, 0.0, heap, hidden);
  }
  std::sort_heap(heap.begin(), heap.end(), comparePairs);
}

// Same as above, using the model's own hidden_ / output_ buffers.
void Model::predict(const std::vector<int32_t>& input, int32_t k,
                    std::vector<std::pair<real, int32_t>>& heap) {
  predict(input, k, heap, hidden_, output_);
}
//vector寻找topk---获得一个最小堆
void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
                      Vector& hidden, Vector& output) const {
  computeOutputSoftmax(hidden, output);//计算soft值
  for (int32_t i = 0; i < osz_; i++) {//输出的大小
    if (heap.size() == k && log(output[i]) < heap.front().first) {//小于topk中最小的那个,最小堆,损失值
      continue;
    }
    heap.push_back(std::make_pair(log(output[i]), i));//加入堆中
    std::push_heap(heap.begin(), heap.end(), comparePairs);//做对排序
    if (heap.size() > k) {//
      std::pop_heap(heap.begin(), heap.end(), comparePairs);//移动最小的那个到最后面,且堆排序
      heap.pop_back();//删除最后一个元素
    }
  }
}
// Top-k search over the hierarchical-softmax tree, starting at `node`
// (predict() passes the root). `score` is the log-probability accumulated
// along the path so far. `heap` is a min-heap (by comparePairs) of at most k
// (score, leaf) pairs. Leaves are nodes with no children; their index is the
// output id.
void Model::dfs(int32_t k, int32_t node, real score,
                std::vector<std::pair<real, int32_t>>& heap,
                Vector& hidden) const {
  // Prune: adding log-probabilities (<= 0) can only lower the score, so this
  // subtree cannot beat the current k-th best.
  if (heap.size() == k && score < heap.front().first) {
    return;
  }

  // Leaf: record its path log-probability as a candidate.
  if (tree[node].left == -1 && tree[node].right == -1) {
    heap.push_back(std::make_pair(score, node));
    std::push_heap(heap.begin(), heap.end(), comparePairs);
    // Keep only the k best: evict the current minimum.
    if (heap.size() > k) {
      std::pop_heap(heap.begin(), heap.end(), comparePairs);
      heap.pop_back();
    }
    return;
  }

  // Internal node `node` uses output row node - osz_. f is the probability of
  // taking the right branch (buildTree marks right children with binary=true).
  real f = sigmoid(wo_->dotRow(hidden, node - osz_));
  // Left branch: probability 1 - f.
  dfs(k, tree[node].left, score + log(1.0 - f), heap, hidden);
  // Right branch: probability f.
  dfs(k, tree[node].right, score + log(f), heap, hidden);
}
// One SGD step for a single example: input ids `input`, output id `target`,
// learning rate `lr`. It computes the hidden vector and applies the configured
// loss (which also updates the output matrix and fills grad_). Then grad_ is
// added to every input row. Empty inputs are ignored.
void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
  assert(target >= 0);
  assert(target < osz_);
  if (input.empty()) {
    return;
  }
  computeHidden(input, hidden_);
  switch (args_->loss) {
    case loss_name::ns:
      loss_ += negativeSampling(target, lr);
      break;
    case loss_name::hs:
      loss_ += hierarchicalSoftmax(target, lr);
      break;
    default:
      loss_ += softmax(target, lr);
      break;
  }
  nexamples_ += 1;

  // Supervised mode scales the input gradient by 1 / |input|, matching the
  // averaging done in computeHidden.
  if (args_->model == model_name::sup) {
    grad_.mul(1.0 / input.size());
  }
  for (const int32_t id : input) {
    wi_->addRow(grad_, id, 1.0);
  }
}

// Prepares loss-specific structures from per-output counts: the sampling table
// for negative sampling, or the Huffman tree for hierarchical softmax.
void Model::setTargetCounts(const std::vector<int64_t>& counts) {
  assert(counts.size() == osz_);
  if (args_->loss == loss_name::ns) {
    initTableNegatives(counts);
  } else if (args_->loss == loss_name::hs) {
    buildTree(counts);
  }
}
//负采样的采样表获取
void Model::initTableNegatives(const std::vector<int64_t>& counts) {
  real z = 0.0;
  for (size_t i = 0; i < counts.size(); i++) {
    z += pow(counts[i], 0.5);//采取是词频的0.5次方
  }
  for (size_t i = 0; i < counts.size(); i++) {
    real c = pow(counts[i], 0.5);//c值
    //0,0,0,1,1,1,1,1,1,1,2,2类似这种有序的,0表示第一个词,占个坑,随机读取时,越多则概率越大。所有词的随机化
    //最多重复次数,若是c/z足够小,会导致重复次数很少,最小是1次
    //NEGATIVE_TABLE_SIZE含义是一个词最多重复不能够超过的值
    for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {//该词映射到表的维度上的取值情况,也就是不等分区映射到等区分段上
      negatives.push_back(i);
    }
  }
  std::shuffle(negatives.begin(), negatives.end(), rng);//随机化一下,均匀随机化,
}
//对于词target获取负采样的值
int32_t Model::getNegative(int32_t target) {
  int32_t negative;
  do {
    negative = negatives[negpos];//由于表是随机化的,取值就是随机采的
    negpos = (negpos + 1) % negatives.size();//下一个,不断的累加的,由于表格随机的,所以不需要pos随机了
  } while (target == negative);//若是遇到为正样本则跳过
  return negative;
}
// Builds the Huffman tree used by hierarchical softmax, then records each
// leaf's root path (paths) and branch codes (codes).
// Nodes 0..osz_-1 are leaves (one per output id); nodes osz_..2*osz_-2 are
// internal, and the last one is the root. Relies on `counts` being in
// non-increasing order (Dictionary::getCounts returns words_ order, which
// threshold() sorts by decreasing count), so walking leaves from osz_-1
// downward visits them in increasing count.
void Model::buildTree(const std::vector<int64_t>& counts) {
  tree.resize(2 * osz_ - 1);
  for (int32_t i = 0; i < 2 * osz_ - 1; i++) {
    tree[i].parent = -1;
    tree[i].left = -1;
    tree[i].right = -1;
    // Sentinel "infinite" count so internal nodes not yet built are never
    // picked before a real leaf.
    tree[i].count = 1e15;
    tree[i].binary = false;
  }
  for (int32_t i = 0; i < osz_; i++) {
    tree[i].count = counts[i];
  }
  // Two-queue Huffman merge: `leaf` walks the leaves in increasing count and
  // `node` walks the internal nodes in creation (also increasing count) order.
  int32_t leaf = osz_ - 1;
  int32_t node = osz_;
  for (int32_t i = osz_; i < 2 * osz_ - 1; i++) {
    // Pick the two smallest remaining nodes from the fronts of both queues.
    int32_t mini[2];
    for (int32_t j = 0; j < 2; j++) {
      if (leaf >= 0 && tree[leaf].count < tree[node].count) {
        mini[j] = leaf--;
      } else {
        mini[j] = node++;
      }
    }
    tree[i].left = mini[0];
    tree[i].right = mini[1];
    tree[i].count = tree[mini[0]].count + tree[mini[1]].count;
    tree[mini[0]].parent = i;
    tree[mini[1]].parent = i;
    // The right child is coded as 1 (true); the left keeps 0 (false).
    tree[mini[1]].binary = true;
  }
  // For every leaf, walk up to the root. Each step records the parent's
  // output-matrix row (parent - osz_) and the branch taken to reach it.
  for (int32_t i = 0; i < osz_; i++) {
    std::vector<int32_t> path;
    std::vector<bool> code;
    int32_t j = i;
    while (tree[j].parent != -1) {
      path.push_back(tree[j].parent - osz_);
      code.push_back(tree[j].binary);
      j = tree[j].parent;
    }
    paths.push_back(path);
    codes.push_back(code);
  }
}
// Returns the average loss per processed example: the accumulated loss_
// divided by nexamples_.
real Model::getLoss() const {
  const real average = loss_ / nexamples_;
  return average;
}
// Precomputes SIGMOID_TABLE_SIZE + 1 samples of the logistic function.
// The samples lie on an evenly spaced grid over [-MAX_SIGMOID, MAX_SIGMOID]
// and are used by sigmoid() for table lookup. The table is allocated with
// new[]; where it is released is not visible here.
void Model::initSigmoid() {
  const int entries = SIGMOID_TABLE_SIZE + 1;
  t_sigmoid = new real[entries];
  for (int k = 0; k < entries; ++k) {
    real x = real(k * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
    t_sigmoid[k] = 1.0 / (1.0 + std::exp(-x));
  }
}
// Precomputes LOG_TABLE_SIZE + 1 samples of the natural log on [0, 1] for
// use by log(). Entry k holds log((k + 1e-5) / LOG_TABLE_SIZE); the small
// offset keeps entry 0 finite instead of log(0) = -inf. The table is
// allocated with new[]; where it is released is not visible here.
void Model::initLog() {
  const int entries = LOG_TABLE_SIZE + 1;
  t_log = new real[entries];
  for (int k = 0; k < entries; ++k) {
    real x = (real(k) + 1e-5) / LOG_TABLE_SIZE;
    t_log[k] = std::log(x);
  }
}
// Table-based natural log for x in [0, 1]: x is truncated to the table grid
// and looked up in t_log. Inputs above 1 return 0 (= log 1).
// NOTE(review): negative inputs would index before the table; the callers in
// this file pass probabilities.
real Model::log(real x) const {
  return (x > 1.0) ? real(0.0) : t_log[int(x * LOG_TABLE_SIZE)];
}
// Table-based logistic function. It saturates to 0 below -MAX_SIGMOID and to
// 1 above MAX_SIGMOID. Otherwise it returns the nearest lower sample from
// t_sigmoid.
real Model::sigmoid(real x) const {
  if (x < -MAX_SIGMOID) return 0.0;
  if (x > MAX_SIGMOID) return 1.0;
  const int slot = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
  return t_sigmoid[slot];
}

}
model.cc

说明:

1:模型的核心在于参数更新,即update函数:该函数根据参数(args_->loss)选择不同的损失函数进行训练,共提供了3种方式:负采样(ns)、层次softmax(hs)和普通softmax。

2:前两种方式(负采样与层次softmax)的更新过程有相同之处,因此把共有的处理逻辑提取成了公共部分。二者的区别在于:负采样只选取部分词(目标词加若干负样本)进行更新,而层次softmax则沿哈夫曼树的路径,把更新累加到公共的内部节点上。

四:fasttext.cc

/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "fasttext.h"

#include <math.h>

#include <iostream>
#include <iomanip>
#include <thread>
#include <string>
#include <vector>
#include <algorithm>

namespace fasttext {
// Computes the vector of `word` as the mean of the input_ rows of the indices
// returned by dict_->getNgrams(word). If there are no such indices, `vec` is
// left at zero.
// NOTE(review): whether this covers out-of-vocabulary words depends on
// Dictionary::getNgrams (presumably via character n-grams) — confirm.
void FastText::getVector(Vector& vec, const std::string& word) {
  const std::vector<int32_t>& ngrams = dict_->getNgrams(word);
  vec.zero();
  for (const int32_t row : ngrams) {
    vec.addRow(*input_, row);
  }
  if (!ngrams.empty()) {
    vec.mul(1.0 / ngrams.size());
  }
}
// Writes the word vectors to "<output>.vec" as text.
// The file starts with a header line "<nwords> <dim>". It is followed by one
// line per vocabulary word: the word, a space, then its vector (see
// getVector). The process exits if the file cannot be opened.
void FastText::saveVectors() {
  std::ofstream ofs(args_->output + ".vec");
  if (!ofs.is_open()) {
    // Report on stderr, like every other file-opening failure in this file.
    std::cerr << "Error opening file for saving vectors." << std::endl;
    exit(EXIT_FAILURE);
  }
  ofs << dict_->nwords() << " " << args_->dim << '\n';
  Vector vec(args_->dim);
  for (int32_t i = 0; i < dict_->nwords(); i++) {
    const std::string& word = dict_->getWord(i);
    getVector(vec, word);
    // '\n' rather than std::endl: flushing after every line is needlessly
    // slow for large vocabularies; close() flushes the stream at the end.
    ofs << word << " " << vec << '\n';
  }
  ofs.close();
}
// Serializes the model to "<output>.bin" in binary mode. The sections are
// written in this order: args, dictionary, input matrix, output matrix.
// loadModel(std::istream&) reads them back in the same order. The process
// exits if the file cannot be opened.
void FastText::saveModel() {
  const std::string path = args_->output + ".bin";
  std::ofstream ofs(path, std::ofstream::binary);
  if (!ofs.is_open()) {
    std::cerr << "Model file cannot be opened for saving!" << std::endl;
    exit(EXIT_FAILURE);
  }
  args_->save(ofs);
  dict_->save(ofs);
  input_->save(ofs);
  output_->save(ofs);
  ofs.close();
}
// Opens `filename` in binary mode and delegates to loadModel(std::istream&).
// The process exits if the file cannot be opened.
void FastText::loadModel(const std::string& filename) {
  std::ifstream ifs(filename, std::ifstream::binary);
  if (ifs.is_open()) {
    loadModel(ifs);
    ifs.close();
    return;
  }
  std::cerr << "Model file cannot be opened for loading!" << std::endl;
  exit(EXIT_FAILURE);
}

void FastText::loadModel(std::istream& in) {
  args_ = std::make_shared<Args>();
  dict_ = std::make_shared<Dictionary>(args_);
  input_ = std::make_shared<Matrix>();
  output_ = std::make_shared<Matrix>();
  args_->load(in);
  dict_->load(in);
  input_->load(in);
  output_->load(in);
  model_ = std::make_shared<Model>(input_, output_, args_, 0);//传的是指针,改变可以带回
  if (args_->model == model_name::sup) {//构建模型的过程
    model_->setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model_->setTargetCounts(dict_->getCounts(entry_type::word));
  }
}
// Prints a single progress line, rewritten in place with '\r'. The line shows
// the percentage done, words/sec/thread, the current learning rate, the loss
// and the estimated time left.
//
// Elapsed time is measured with clock() (processor time) since `start`.
//
// @param progress fraction of the total token budget processed, in [0, 1].
// @param loss     average loss to display.
void FastText::printInfo(real progress, real loss) {
  const real t = real(clock() - start) / CLOCKS_PER_SEC;  // elapsed seconds
  const real lr = args_->lr * (1.0 - progress);           // linearly decayed lr
  // Guard the divisions below. trainThread passes the progress it sampled
  // *before* adding its tokens to tokenCount, so the first report can have
  // progress == 0. With a coarse clock, t can also be 0. The old code then
  // produced inf and converted it to int, which is undefined behavior.
  real wst = 0.0;  // tokens per second
  int eta = 0;     // estimated remaining seconds; 0 while still unknown
  if (t > 0) {
    wst = real(tokenCount) / t;
  }
  if (progress > 0 && t > 0) {
    const real remaining = t / progress * (1 - progress) / args_->thread;
    // Clamp before converting: int(x) is undefined when x does not fit.
    eta = remaining < 2147483647.0 ? int(remaining) : 2147483647;
  }
  const int etah = eta / 3600;
  const int etam = (eta - etah * 3600) / 60;
  std::cout << std::fixed;
  std::cout << "\rProgress: " << std::setprecision(1) << 100 * progress << "%";
  std::cout << "  words/sec/thread: " << std::setprecision(0) << wst;
  std::cout << "  lr: " << std::setprecision(6) << lr;
  std::cout << "  loss: " << std::setprecision(6) << loss;
  std::cout << "  eta: " << etah << "h" << etam << "m ";
  std::cout << std::flush;
}

// One supervised step. A label is picked uniformly at random from the
// example's labels (using the model's RNG), and the model is updated to
// predict it from `line`. Examples without labels or tokens are skipped.
void FastText::supervised(Model& model, real lr,
                          const std::vector<int32_t>& line,
                          const std::vector<int32_t>& labels) {
  if (line.empty() || labels.empty()) {
    return;
  }
  std::uniform_int_distribution<> pick(0, labels.size() - 1);
  const int32_t target = labels[pick(model.rng)];
  model.update(line, target, lr);
}
// CBOW training over one line. For every position w, a window radius in
// [1, ws] is drawn from the model's RNG. The n-gram indices of all other
// words inside that window are concatenated, in left-to-right order, into
// one context. The model is then updated to predict line[w] from that
// context.
void FastText::cbow(Model& model, real lr,
                    const std::vector<int32_t>& line) {
  const int32_t len = line.size();
  std::vector<int32_t> bow;
  std::uniform_int_distribution<> uniform(1, args_->ws);
  for (int32_t w = 0; w < len; w++) {
    const int32_t boundary = uniform(model.rng);
    const int32_t first = std::max<int32_t>(0, w - boundary);
    const int32_t last = std::min<int32_t>(len - 1, w + boundary);
    bow.clear();
    for (int32_t p = first; p <= last; p++) {
      if (p == w) {
        continue;  // the centre word is the target, not context
      }
      const std::vector<int32_t>& ngrams = dict_->getNgrams(line[p]);
      bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend());
    }
    model.update(bow, line[w], lr);
  }
}
// Skip-gram training over one line. For every position w, a window radius in
// [1, ws] is drawn from the model's RNG. For each other word inside the
// window, taken left to right, the model is updated to predict that word
// from the n-gram indices of line[w].
void FastText::skipgram(Model& model, real lr,
                        const std::vector<int32_t>& line) {
  const int32_t len = line.size();
  std::uniform_int_distribution<> uniform(1, args_->ws);
  for (int32_t w = 0; w < len; w++) {
    const int32_t boundary = uniform(model.rng);
    const std::vector<int32_t>& ngrams = dict_->getNgrams(line[w]);
    const int32_t first = std::max<int32_t>(0, w - boundary);
    const int32_t last = std::min<int32_t>(len - 1, w + boundary);
    for (int32_t p = first; p <= last; p++) {
      if (p != w) {
        model.update(ngrams, line[p], lr);
      }
    }
  }
}
//测试模型
void FastText::test(std::istream& in, int32_t k) {
  int32_t nexamples = 0, nlabels = 0;
  double precision = 0.0;
  std::vector<int32_t> line, labels;

  while (in.peek() != EOF) {
    dict_->getLine(in, line, labels, model_->rng);//获取句子
    dict_->addNgrams(line, args_->wordNgrams);//对句子增加其ngram
    if (labels.size() > 0 && line.size() > 0) {
      std::vector<std::pair<real, int32_t>> modelPredictions;
      model_->predict(line, k, modelPredictions);//预测
      for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
        if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
          precision += 1.0;//准确数
        }
      }
      nexamples++;
      nlabels += labels.size();
    }
  }
  std::cout << std::setprecision(3);
  std::cout << "P@" << k << ": " << precision / (k * nexamples) << std::endl;
  std::cout << "R@" << k << ": " << precision / nlabels << std::endl;
  std::cout << "Number of examples: " << nexamples << std::endl;
}
//预测
void FastText::predict(std::istream& in, int32_t k,
                       std::vector<std::pair<real,std::string>>& predictions) const {
  std::vector<int32_t> words, labels;
  dict_->getLine(in, words, labels, model_->rng);
  dict_->addNgrams(words, args_->wordNgrams);
  if (words.empty()) return;
  Vector hidden(args_->dim);
  Vector output(dict_->nlabels());
  std::vector<std::pair<real,int32_t>> modelPredictions;
  model_->predict(words, k, modelPredictions, hidden, output);
  predictions.clear();
  for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
    predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second)));//不同标签的预测分
  }
}
//预测
void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
  std::vector<std::pair<real,std::string>> predictions;
  while (in.peek() != EOF) {
    predict(in, k, predictions);
    if (predictions.empty()) {
      std::cout << "n/a" << std::endl;
      continue;
    }
    for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
      if (it != predictions.cbegin()) {
        std::cout << ' ';
      }
      std::cout << it->second;
      if (print_prob) {
        std::cout << ' ' << exp(it->first);
      }
    }
    std::cout << std::endl;
  }
}
// Reads whitespace-separated words from stdin until extraction fails.
// Each word is printed followed by its vector (see getVector).
void FastText::wordVectors() {
  Vector vec(args_->dim);
  for (std::string word; std::cin >> word;) {
    getVector(vec, word);
    std::cout << word << " " << vec << std::endl;
  }
}
// Prints one vector for each line of stdin. The vector is the mean of the
// input_ rows of the line's token indices plus the word n-gram indices added
// by addNgrams. An empty line prints an all-zero vector.
void FastText::textVectors() {
  std::vector<int32_t> line, labels;
  Vector vec(args_->dim);
  while (std::cin.peek() != EOF) {
    dict_->getLine(std::cin, line, labels, model_->rng);
    dict_->addNgrams(line, args_->wordNgrams);
    vec.zero();
    for (const int32_t row : line) {
      vec.addRow(*input_, row);
    }
    if (!line.empty()) {
      vec.mul(1.0 / line.size());
    }
    std::cout << vec << std::endl;
  }
}

// Reads queries from stdin. Prints sentence vectors for supervised models and
// word vectors otherwise.
void FastText::printVectors() {
  if (args_->model != model_name::sup) {
    wordVectors();
    return;
  }
  textVectors();
}
// Body of one training thread; train() launches args_->thread of these.
//
// Each thread opens its own stream on args_->input and seeks to byte offset
// threadId * fileSize / args_->thread, so the threads start in different
// parts of the file. Each thread builds its own Model around the shared
// input_ and output_ matrices. No locking is done here, so threads update
// those matrices concurrently (presumably Hogwild-style lock-free SGD —
// confirm). Lines are processed until the shared counter tokenCount reaches
// args_->epoch * ntokens. The learning rate decays linearly from args_->lr
// towards 0 as that counter grows.
void FastText::trainThread(int32_t threadId) {
  std::ifstream ifs(args_->input);
  // Start reading at this thread's share of the file.
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

  // NOTE(review): threadId is also passed to the Model constructor,
  // presumably to seed its RNG differently per thread — confirm.
  Model model(input_, output_, args_, threadId);
  // Targets are labels for supervised training, words otherwise.
  if (args_->model == model_name::sup) {
    model.setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model.setTargetCounts(dict_->getCounts(entry_type::word));
  }

  const int64_t ntokens = dict_->ntokens();
  // Sum of getLine() return values since this thread last added to tokenCount.
  int64_t localTokenCount = 0;
  std::vector<int32_t> line, labels;
  // Total budget: epoch passes over ntokens tokens, shared by all threads.
  while (tokenCount < args_->epoch * ntokens) {
    real progress = real(tokenCount) / (args_->epoch * ntokens);
    real lr = args_->lr * (1.0 - progress);
    localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
    // Dispatch to the update routine of the configured model type.
    if (args_->model == model_name::sup) {
      dict_->addNgrams(line, args_->wordNgrams);
      supervised(model, lr, line, labels);
    } else if (args_->model == model_name::cbow) {
      cbow(model, lr, line);
    } else if (args_->model == model_name::sg) {
      skipgram(model, lr, line);
    }
    // Publish the local count once it exceeds lrUpdateRate tokens. This is
    // what advances the global progress, and therefore the lr decay.
    if (localTokenCount > args_->lrUpdateRate) {
      tokenCount += localTokenCount;
      localTokenCount = 0;
      // `progress` was sampled before this batch was published, so it can be
      // 0 on the first report.
      if (threadId == 0 && args_->verbose > 1) {
        printInfo(progress, model.getLoss());
      }
    }
  }
  // Final report, printed by thread 0 only.
  if (threadId == 0 && args_->verbose > 0) {
    printInfo(1.0, model.getLoss());
    std::cout << std::endl;
  }
  ifs.close();
}
// Initializes input_ from a pretrained vector file in text format.
//
// Expected format: a header "<n> <dim>", then n records "<word> <dim values>".
// The steps are:
//   1. every pretrained word is added to the dictionary;
//   2. the dictionary is re-thresholded with threshold(1, 0);
//   3. input_ is allocated with nwords + bucket rows and initialized with
//      uniform(1/dim);
//   4. rows of dictionary words found in the file are overwritten with their
//      pretrained values.
// The process exits if the file cannot be opened, the header is unreadable or
// has a negative count, the dimension differs from args_->dim, or a record is
// truncated or malformed.
void FastText::loadVectors(std::string filename) {
  std::ifstream in(filename);
  std::vector<std::string> words;
  std::shared_ptr<Matrix> mat; // temp. matrix for pretrained vectors
  int64_t n, dim;
  if (!in.is_open()) {
    std::cerr << "Pretrained vectors file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  // Validate the header. An unreadable header used to surface only as a
  // misleading dimension-mismatch error, and a negative n reached Matrix.
  if (!(in >> n >> dim) || n < 0) {
    std::cerr << "Pretrained vectors file has an invalid header!" << std::endl;
    exit(EXIT_FAILURE);
  }
  if (dim != args_->dim) {
    std::cerr << "Dimension of pretrained vectors does not match -dim option"
              << std::endl;
    exit(EXIT_FAILURE);
  }
  // Every extraction is checked. Once the stream fails, later reads store
  // nothing. A truncated or malformed file therefore used to add empty words
  // to the dictionary, and left the remaining matrix entries as Matrix(n, dim)
  // had initialized them.
  const auto abortMalformed = []() {
    std::cerr << "Pretrained vectors file is truncated or malformed!" << std::endl;
    exit(EXIT_FAILURE);
  };
  mat = std::make_shared<Matrix>(n, dim);
  for (int64_t i = 0; i < n; i++) {
    std::string word;
    if (!(in >> word)) {
      abortMalformed();
    }
    words.push_back(word);
    dict_->add(word);
    for (int64_t j = 0; j < dim; j++) {
      if (!(in >> mat->data_[i * dim + j])) {
        abortMalformed();
      }
    }
  }
  in.close();

  dict_->threshold(1, 0);
  input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
  input_->uniform(1.0 / args_->dim);

  // Copy pretrained rows for words that ended up in the word vocabulary.
  for (size_t i = 0; i < words.size(); i++) {
    const int32_t idx = dict_->getId(words[i]);
    if (idx < 0 || idx >= dict_->nwords()) continue;
    for (int64_t j = 0; j < dim; j++) {
      input_->data_[idx * dim + j] = mat->data_[i * dim + j];
    }
  }
}
//训练
void FastText::train(std::shared_ptr<Args> args) {
  args_ = args;
  dict_ = std::make_shared<Dictionary>(args_);
  if (args_->input == "-") {
    // manage expectations
    std::cerr << "Cannot use stdin for training!" << std::endl;
    exit(EXIT_FAILURE);
  }
  std::ifstream ifs(args_->input);
  if (!ifs.is_open()) {
    std::cerr << "Input file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  dict_->readFromFile(ifs);
  ifs.close();

  if (args_->pretrainedVectors.size() != 0) {
    loadVectors(args_->pretrainedVectors);
  } else {
    input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
    input_->uniform(1.0 / args_->dim);
  }

  if (args_->model == model_name::sup) {
    output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
  } else {
    output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
  }
  output_->zero();

  start = clock();
  tokenCount = 0;
  std::vector<std::thread> threads;
  for (int32_t i = 0; i < args_->thread; i++) {
    threads.push_back(std::thread([=]() { trainThread(i); }));
  }
  for (auto it = threads.begin(); it != threads.end(); ++it) {
    it->join();
  }
  model_ = std::make_shared<Model>(input_, output_, args_, 0);

  saveModel();
  if (args_->model != model_name::sup) {
    saveVectors();
  }
}

}
fasttext.cc

 

posted @ 2016-12-25 15:34  miner007  阅读(1784)  评论(0编辑  收藏  举报