Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练
Alink漫谈(十七) :Word2Vec源码分析 之 迭代训练
0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文和上文将带领大家来分析Alink中 Word2Vec 的实现。
因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。
0x01 前文回顾
从前文 Alink漫谈(十六) :Word2Vec之建立霍夫曼树 我们了解了Word2Vec的概念、在Alink中的整体架构以及完成对输入的处理,以及词典、二叉树的建立。
此时我们已经有了一个已经构造好的Huffman树,以及初始化完毕的各个向量,可以开始输入文本来进行训练了。
1.1 上文总体流程图
先给出一个上文总体流程图:
1.2 回顾霍夫曼树
1.2.1 变量定义
现在定义变量如下:
- n : 一个词的上下文包含的词数,与n-gram中n的含义相同
- m : 词向量的长度,通常在10~100
- h : 隐藏层的规模,一般在100量级
- N :词典的规模,通常在1W~10W
- T : 训练文本中单词个数
1.2.2 为何要引入霍夫曼树
word2vec也使用了CBOW与Skip-Gram来训练模型与得到词向量,但是并没有使用传统的DNN模型。最先优化使用的数据结构是用霍夫曼树来代替隐藏层和输出层的神经元,霍夫曼树的叶子节点起到输出层神经元的作用,叶子节点的个数即为词汇表的小大。 而内部节点则起到隐藏层神经元的作用。
以CBOW为例,输入层为n-1个单词的词向量,长度为m(n-1),隐藏层的规模为h,输出层的规模为N。那么前向的时间复杂度就是o(m(n-1)h+hN) = o(hN) 这还是处理一个词所需要的复杂度。如果要处理所有文本,则需要o(hNT)的时间复杂度。这个是不可接受的。
同时我们也注意到,o(hNT)之中,h和T的值相对固定,想要对其进行优化,主要还是应该从N入手。而输出层的规模之所以为N,是因为这个神经网络要完成的是N选1的任务。那么可不可以减小N的值呢?答案是可以的。解决的思路就是将一次分类分解为多次分类,这也是Hierarchical Softmax的核心思想。
举个栗子,有[1,2,3,4,5,6,7,8]这8个分类,想要判断词A属于哪个分类,我们可以一步步来,首先判断A是属于[1,2,3,4]还是属于[5,6,7,8]。如果判断出属于[1,2,3,4],那么就进一步分析是属于[1,2]还是[3,4],以此类推。这样一来,就把单个词的时间复杂度从o(hN)降为o(hlogN),更重要的减少了内存的开销。
从输入到输出,中间是一个树形结构,其中的每一个节点都完成一个二分类(logistic分类)问题。那么就存在一个如何构建树的问题。这里采用huffman树,因为这样构建的话,出现频率越高的词所经过的路径越短,从而使得所有单词的平均路径长度达到最短。
假设 足球 的路径 1001,那么当要输出 足球 这个词的时候,这个模型其实并不是直接输出 "1001" 这条路径,而是在每一个节点都进行一次二分类。这样相当于将最后输出的二叉树变成多个二分类的任务。而路径中的每个根节点都是一个待求的向量。 也就是说这个模型不仅需要求每个输入参数的变量,还需要求这棵二叉树中每个非叶子节点的向量,当然这些向量都只是临时用的向量。
0x02 训练
2.1 训练流程
Alink 实现的是 基于Hierarchical Softmax的Skip-Gram模型。
现在我们先看看基于Skip-Gram模型时, Hierarchical Softmax如何使用。此时输入的只有一个词w,输出的为2c个词向量context(w)。
训练的过程主要有输入层(input),映射层(projection)和输出层(output)三个阶段。
- 我们对于训练样本中的每一个词,该词本身作为样本的输入,其前面的c个词和后面的c个词作为了Skip-Gram模型的输出,期望这些词的softmax概率比其他的词大。
- 我们需要先将词汇表建立成一颗霍夫曼树(此步骤在上文已经完成)。
- 对于从输入层到隐藏层(映射层),这一步比CBOW简单,由于只有一个词,所以,即x_w就是词w对应的词向量。
- 通过梯度上升法来更新我们的θwj−1和x_w,注意这里的x_w周围有2c个词向量,我们期望P(xi|xw),i=1,2...2c最大。此时我们注意到由于上下文是相互的,在期望P(xi|xw),i=1,2...2c最大化的同时,反过来我们也期望P(xw|xi),i=1,2...2c最大。那么是使用P(xi|xw)好还是P(xw|xi)好呢,word2vec使用了后者,这样做的好处就是在一个迭代窗口内,我们不是只更新xw一个词,而是xi,i=1,2...2c共2c个词。这样整体的迭代会更加的均衡。因为这个原因,Skip-Gram模型并没有和CBOW模型一样对输入进行迭代更新,而是对2c个输出进行迭代更新。
- 从根节点开始,映射层的值需要沿着Huffman树不断的进行logistic分类,并且不断的修正各中间向量和词向量。
- 假设映射层输入为 pro(t),单词为“足球”,即w(t)=“足球”,假设其Huffman码可知为d(t)=”1001,那么根据Huffman码可知,从根节点到叶节点的路径为“左右右左”,即从根节点开始,先往左拐,再往右拐2次,最后再左拐。
- 既然知道了路径,那么就按照路径从上往下依次修正路径上各节点的中间向量。在第一个节点,根据节点的中间向量Θ(t,1)和pro(t)进行Logistic分类。如果分类结果显示为0,则表示分类错误 (应该向左拐,即分类到1),则要对Θ(t,1)进行修正,并记录误差量。
- 接下来,处理完第一个节点之后,开始处理第二个节点。方法类似,修正Θ(t,2),并累加误差量。接下来的节点都以此类推。
- 在处理完所有节点,达到叶节点之后,根据之前累计的误差来修正词向量v(w(t))。这里引入学习率概念,η 表示学习率。学习率越大,则判断错误的惩罚也越大,对中间向量的修正跨度也越大。
这样,一个词w(t)的处理流程就结束了。如果一个文本中有N个词,则需要将上述过程在重复N遍,从w(0)~w(N-1)。
这里总结下基于Hierarchical Softmax的Skip-Gram模型算法流程,梯度迭代使用了随机梯度上升法:
- 输入:基于Skip-Gram的语料训练样本,词向量的维度大小M,Skip-Gram的上下文大小2c,步长η
- 输出:霍夫曼树的内部节点模型参数θ,所有的词向量w
2.2 生成训练模型
Huffman树中非叶节点存储的中间向量的初始化值是零向量,而叶节点对应的单词的词向量是随机初始化的。我们可以看到,对于 input 和 output,是会进行 AllReduce 的,就是聚合每个task的计算结果。
DataSet <Row> model = new IterativeComQueue()
.initWithPartitionedData("trainData", trainData)
.initWithBroadcastData("vocSize", vocSize)
.initWithBroadcastData("initialModel", initialModel)
.initWithBroadcastData("vocabWithoutWordStr", vocabWithoutWordStr)
.initWithBroadcastData("syncNum", syncNum)
.add(new InitialVocabAndBuffer(getParams()))
.add(new UpdateModel(getParams()))
.add(new AllReduce("input"))
.add(new AllReduce("output"))
.add(new AvgInputOutput())
.setCompareCriterionOfNode0(new Criterion(getParams()))
.closeWith(new SerializeModel(getParams()))
.exec();
2.3 初始化词典&缓冲
InitialVocabAndBuffer类完成此功能,主要是初始化参数,把词典加载到模型内存中。这里只有迭代第一次才会运行。
- input 数组存着Vocab的全部词向量,就是Huffman树所有叶子节点的词向量。大小|V|∗|M|,初始化范围[−0.5M,0.5M],经验规则。
- output 数组存着Hierarchical Softmax的参数,就是Huffman树所有 非叶子 节点的参数向量(映射层到输出层之间的权重)。大小|V|∗|M|,初始化全为0,经验规则。实际使用|V−1|组。
private static class InitialVocabAndBuffer extends ComputeFunction {
Params params;
public InitialVocabAndBuffer(Params params) {
this.params = params;
}
@Override
public void calc(ComContext context) {
if (context.getStepNo() == 1) { // 只有迭代第一次才会运行
int vectorSize = params.get(Word2VecTrainParams.VECTOR_SIZE);
List <Long> vocSizeList = context.getObj("vocSize");
List <Tuple2 <Integer, double[]>> initialModel = context.getObj("initialModel");
List <Tuple2 <Integer, Word>> vocabWithoutWordStr = context.getObj("vocabWithoutWordStr");
int vocSize = vocSizeList.get(0).intValue();
// 生成一个 100 x 12 的input,这个迭代之后就是最终的词向量
double[] input = new double[vectorSize * vocSize];
Word[] vocab = new Word[vocSize];
for (int i = 0; i < vocSize; ++i) {
Tuple2 <Integer, double[]> item = initialModel.get(i);
System.arraycopy(item.f1, 0, input,
item.f0 * vectorSize, vectorSize); //初始化词向量
Tuple2 <Integer, Word> vocabItem = vocabWithoutWordStr.get(i);
vocab[vocabItem.f0] = vocabItem.f1;
}
context.putObj("input", input); // 把词向量放入系统上下文
// 生成一个 100 x 11 的output,就是Hierarchical Softmax的参数
context.putObj("output", new double[vectorSize * (vocSize - 1)]);
context.putObj("vocab", vocab);
context.removeObj("initialModel");
context.removeObj("vocabWithoutWordStr");
}
}
}
2.4 更新模型UpdateModel
这里进行“分布式计算”的分配。其中,如何计算给哪个task发送多少/发送起始位置,是在DefaultDistributedInfo完成的。这里需要结合 pieces 函数进行分析。具体在 [ Alink漫谈之三] AllReduce通信模型 有详细介绍。具体计算则是在 CalcModel.update 中完成。
private static class UpdateModel extends ComputeFunction {
@Override
public void calc(ComContext context) {
List <int[]> trainData = context.getObj("trainData");
int syncNum = ((List <Integer>) context.getObj("syncNum")).get(0);
DistributedInfo distributedInfo = new DefaultDistributedInfo();
long startPos = distributedInfo.startPos(
(context.getStepNo() - 1) % syncNum,
syncNum,
trainData.size()
);
long localRowCnt = distributedInfo.localRowCnt( //计算本分区信息
(context.getStepNo() - 1) % syncNum,
syncNum,
trainData.size()
);
new CalcModel( //更新模型
params.get(Word2VecTrainParams.VECTOR_SIZE),
System.currentTimeMillis(),
Boolean.parseBoolean(params.get(Word2VecTrainParams.RANDOM_WINDOW)),
params.get(Word2VecTrainParams.WINDOW),
params.get(Word2VecTrainParams.ALPHA),
context.getTaskId(),
context.getObj("vocab"),
context.getObj("input"),
context.getObj("output")
).update(trainData.subList((int) startPos, (int) (startPos + localRowCnt)));
}
}
2.5 计算更新
CalcModel.update 中完成计算更新。
2.5.1 sigmoid函数值近似计算
在利用神经网络模型对样本进行预測的过程中。须要对其进行预測,此时,须要使用到sigmoid函数,sigmoid函数的具体形式为:σ(x)=1 / (1+e^x)
。
σ(x) 在 x = 0 附近变化剧烈,往两边逐渐趋于平缓,当 x > 6 或者 x < -6 时候,函数值就基本不变了,前者趋近于0,后者趋近于1.
假设每一次都请求计算sigmoid值,对性能将会有一定的影响,当sigmoid的值对精度的要求并非非常严格时。能够採用近似计算。在word2vec中。将区间[−6,6](设置的參数MAX_EXP为6)等距离划分成EXP_TABLE_SIZE等份,并将每个区间中的sigmoid值计算好存入到数组expTable中。须要使用时,直接从数组中查找。
Alink中实现如下:
public class ExpTableArray {
public final static float[] sigmoidTable = {
0.002473f, 0.002502f, 0.002532f, 0.002562f, 0.002593f, 0.002624f, 0.002655f, 0.002687f, 0.002719f, 0.002751f ......
}
}
2.5.2 窗口及上下文
Context(w) 就是在词 w 的前后各取 C 个词,Alink是事先设置一个窗口预置参数window(默认为5),每次构造Context(w)时候,首先生成一个[1, window] 上的一个随机整数 C~ ,于是 w 前后各取 C~ 个词就构成了 Context(w)。
if (randomWindow) {
b = random.nextInt(window);
} else {
b = 0;
}
int bound = window * 2 + 1 - b;
for (int a = b; a < bound; ++a) {
.....
}
2.5.3 训练
2.5.3.1 数据结构
在 c语言代码 中:
- syn0数组存着Vocab的全部词向量,就是Huffman树所有叶子节点的词向量,即input -> hidden 的 weights 。大小|V|∗|M|,初始化范围[−0.5M,0.5M],经验规则。在code中是一个1维数组,但是应该按照二维数组来理解。访问时实际上可以看成 syn0[i, j] i为第i个单词,j为第j个隐含单元。
- syn1数组存着Hierarchical Softmax的参数,就是Huffman树所有 非叶子 节点的参数向量,即 hidden----> output 的 weights。大小|V|∗|M|,初始化全为0,经验规则。实际使用|V−1|组。
原本的Softmax问题,被近似退化成了近似log(K)个Logistic回归组合成决策树。
Softmax的K组θ,现在变成了K-1组,代表着二叉树的K-1个非叶结点。在Word2Vec中,由syn1数组存放,。
在 Alink代码 中:
- input 就对应了 syn0,就是上图的 v。
- output 就对应了 syn1,就是上图的 θ。
2.5.3.2 具体代码
具体代码如下(我们使用最大似然法来寻找所有节点的词向量和所有内部节点θ):
private static class CalcModel {
public void update(List <int[]> values) {
double[] neu1e = new double[vectorSize];
double f, g;
int b, c, lastWord, l1, l2;
for (int[] val : values) {
for (int i = 0; i < val.length; ++i) {
if (randomWindow) {
b = random.nextInt(window);
} else {
b = 0;
}
// 在Skip-gram模型中。须要使用当前词分别预測窗体中的词,因此。这是一个循环的过程
// 因为需要预测Context(w)中的每个词,因此需要循环2window - 2b + 1次遍历整个窗口
int bound = window * 2 + 1 - b;
for (int a = b; a < bound; ++a) {
if (a != window) { //遍历时跳过中心单词
c = i - window + a;
if (c < 0 || c >= val.length) {
continue;
}
lastWord = val[c]; //last_word为当前待预测的上下文单词
l1 = lastWord * vectorSize; //l1为当前单词的词向量在syn0中的起始位置
Arrays.fill(neu1e, 0.f); //初始化累计误差
Word w = vocab[val[i]];
int codeLen = w.code.length;
//根据Haffman树上从根节点到当前词的叶节点的路径,遍历所有经过的中间节点
for (int d = 0; d < codeLen; ++d) {
f = 0.f;
//l2为当前遍历到的中间节点的向量在syn1中的起始位置
l2 = w.point[d] * vectorSize;
// 正向传播,得到该编码单元对应的output 值f
//注意!这里用到了模型对称:p(u|w) = p(w|u),其中w为中心词,u为context(w)中每个词, 也就是skip-gram虽然是给中心词预测上下文,真正训练的时候还是用上下文预测中心词, 与CBOW不同的是这里的u是单个词的词向量,而不是窗口向量之和
// 将路径上所有Node连锁起来,累积得到 输入向量与中间结点向量的内积
// f=σ(W.θi)
for (int t = 0; t < vectorSize; ++t) {
// 这里就是 X * Y
// 映射层即为输入层
f += input[l1 + t] * output[l2 + t];
}
if (f > -6.0f && f < 6.0f) {
// 从 ExpTableArray 中查询到相应的值。
f = ExpTableArray.sigmoidTable[(int) ((f + 6.0) * 84.0)];
//@brief此处最核心,loss是交叉熵 Loss=xlogp(x)+(1-x)*log(1-p(x))
//其中p(x)=exp(neu1[c] * syn1[c + l2])/(1+exp(neu1[c] * syn1[c + l2]))
//x=1-code#作者才此处定义label为1-code,实际上也可以是code
//log(L) = (1-x) * neu1[c] * syn1[c + l2] -x*log(1 + exp(neu1[c] * syn1[c + l2]))
//对log(L)中的syn1进行偏导,g=(1 -code - p(x))*syn1
//因此会有
//g = (1 - vocab[word].code[d] - f) * alpha;alpha学习速率
// 'g' is the gradient multiplied by the learning rate
// g是梯度和学习率的乘积
//注意!word2vec中将Haffman编码为1的节点定义为负类,而将编码为0的节点定义为正类,即一个节点的label = 1 - d
g = (1.f - w.code[d] - f) * alpha;
// Propagate errors output -> hidden
// 根据计算得到的修正量g和中间节点的向量更新累计误差
for (int t = 0; t < vectorSize; ++t) {
neu1e[t] += g * output[l2 + t]; // 修改映射后的结果
}
// Learn weights hidden -> output
for (int t = 0; t < vectorSize; ++t) {
output[l2 + t] += g * input[l1 + t]; // 改动映射层到输出层之间的权重
}
}
}
for (int t = 0; t < vectorSize; ++t) {
input[l1 + t] += neu1e[t]; // 返回改动每个词向量
}
}
}
}
}
}
}
2.6 平均化
AvgInputOutput 类会对Input,output做平均化。
.add(new AllReduce("input"))
.add(new AllReduce("output"))
.add(new AvgInputOutput())
原因在于做AllReduce时候,会简单的累积,如果有 context.getNumTask() 个task在同时进行,就容易简单粗暴的相加,这样数值就会扩大 context.getNumTask() 倍。
private static class AvgInputOutput extends ComputeFunction {
@Override
public void calc(ComContext context) {
double[] input = context.getObj("input");
for (int i = 0; i < input.length; ++i) {
input[i] /= context.getNumTask(); //平均化
}
double[] output = context.getObj("output");
for (int i = 0; i < output.length; ++i) {
output[i] /= context.getNumTask(); //平均化
}
}
}
2.7 判断收敛
这里能够看到,收敛就是判断是否达到迭代次数。
private static class Criterion extends CompareCriterionFunction {
@Override
public boolean calc(ComContext context) {
return (context.getStepNo() - 1)
== ((List <Integer>) context.getObj("syncNum")).get(0)
* params.get(Word2VecTrainParams.NUM_ITER);
}
}
2.8 序列化模型
这是在 context.getTaskId() 为 0 的task中完成序列化操作,其他task直接返回。这里收集了所有task的计算结果。
private static class SerializeModel extends CompleteResultFunction {
@Override
public List <Row> calc(ComContext context) {
// 在 context.getTaskId() 为 0 的task中完成序列化操作,其他task直接返回
if (context.getTaskId() != 0) {
return null; //其他task直接返回
}
int vocSize = ((List <Long>) context.getObj("vocSize")).get(0).intValue();
int vectorSize = params.get(Word2VecTrainParams.VECTOR_SIZE);
List <Row> ret = new ArrayList <>(vocSize);
double[] input = context.getObj("input");
for (int i = 0; i < vocSize; ++i) {
// 完成序列化操作
DenseVector dv = new DenseVector(vectorSize);
System.arraycopy(input, i * vectorSize, dv.getData(), 0, vectorSize);
ret.add(Row.of(i, dv));
}
return ret;
}
}
0x03 输出模型
输出模型的代码如下,功能分别是:
- 把词典和计算出来的向量联系起来
- 按分区分割模型成row
- 发送模型
model = model
.map(new MapFunction <Row, Tuple2 <Integer, DenseVector>>() {
@Override
public Tuple2 <Integer, DenseVector> map(Row value) throws Exception {
return Tuple2.of((Integer) value.getField(0), (DenseVector) value.getField(1));
}
})
.join(vocab)
.where(0)
.equalTo(0) //把词典和计算出来的向量联系起来
.with(new JoinFunction <Tuple2 <Integer, DenseVector>, Tuple3 <Integer, String, Word>, Row>() {
@Override
public Row join(Tuple2 <Integer, DenseVector> first, Tuple3 <Integer, String, Word> second)
throws Exception {
return Row.of(second.f1, first.f1);
}
})
.mapPartition(new MapPartitionFunction <Row, Row>() {
@Override
public void mapPartition(Iterable <Row> values, Collector <Row> out) throws Exception {
Word2VecModelDataConverter model = new Word2VecModelDataConverter();
model.modelRows = StreamSupport
.stream(values.spliterator(), false)
.collect(Collectors.toList());
model.save(model, out);
}
});
setOutput(model, new Word2VecModelDataConverter().getModelSchema());
3.1 联系词典和向量
.join(vocab).where(0).equalTo(0)
就是把词典和计算出来的向量联系起来。两个Join来源分别如下:
// 来源1,计算出来的向量
first = {Tuple2@11501}
f0 = {Integer@11509} 9
f1 = {DenseVector@11502} "0.9371751984171548 0.33341686580829943 0.6472255126130384 0.36692156358000316 0.1187895685629788 0.9223451469664975 0.763874142430857 0.1330720374498615 0.9631811135902764 0.9283700030050634......"
// 来源2,词典
second = {Tuple3@11499} "(9,我们,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@1ffa469)"
f0 = {Integer@11509} 9
f1 = "我们"
f2 = {Word2VecTrainBatchOp$Word@11510}
3.2 按分区分割模型成row
首先按照分区计算,分割模型成row。这里用到了java 8的新特性 StreamSupport,spliterator。
但是这里只是使用到了Stream的方式,没有使用其并行功能(可能过后会有文章进行研究)。
model.modelRows = StreamSupport
.stream(values.spliterator(), false)
.collect(Collectors.toList());
比如某个分区得到:
model = {Word2VecModelDataConverter@11561}
modelRows = {ArrayList@11562} size = 3
0 = {Row@11567} "胖,0.4345151137723066 0.4923534386513069 0.49497589358976174 0.10917632806760409 0.7007392318076214 0.6468149904858065 0.3804865818632239 0.4348997489483902 0.03362685646645655 0.29769437681180916 0.04287936035337748..."
1 = {Row@11568} "的,0.4347763498886036 0.6852891840621573 0.9862851622413142 0.7061202166493431 0.9896492612656784 0.46525497532250026 0.03379287230189395 0.809333161215095 0.9230387687661015 0.5100444513892355 0.02436724648194081..."
2 = {Row@11569} "老王,0.4337285110643647 0.7605192699353084 0.6638406386520266 0.909594031681524 0.26995654043189604 0.3732722125930673 0.16171135697228312 0.9759668223869069 0.40331291071231623 0.22651841541002585 0.7150087001048662...."
......
3.3 发送数据
然后发送数据
public class Word2VecModelDataConverter implements ModelDataConverter<Word2VecModelDataConverter, Word2VecModelDataConverter> {
public List <Row> modelRows;
@Override
public void save(Word2VecModelDataConverter modelData, Collector<Row> collector) {
modelData.modelRows.forEach(collector::collect); //发送数据
}
@Override
public TableSchema getModelSchema() {
return new TableSchema( //返回schema
new String[] {"word", "vec"},
new TypeInformation[] {Types.STRING, VectorTypes.VECTOR}
);
}
}
0x04 问题答案
我们上文提到了一些问题,现在逐一回答:
- 哪些模块用到了Alink的分布式处理能力?答案是:
- 分割单词,计数(为了剔除低频词,排序);
- 单词排序;
- 训练;
- Alink实现了Word2vec的哪个模型?是CBOW模型还是skip-gram模型?答案是:
- skip-gram模型
- Alink用到了哪个优化方法?是Hierarchical Softmax?还是Negative Sampling?答案是:
- Hierarchical Softmax
- 是否在本算法内去除停词?所谓停用词,就是出现频率太高的词,如逗号,句号等等,以至于没有区分度。答案是:
- 本实现中没有去处停词
- 是否使用了自适应学习率?答案是:
- 没有
0xFF 参考
word2vec原理(二) 基于Hierarchical Softmax的模型
word2vec原理(一) CBOW与Skip-Gram模型基础
word2vec原理(三) 基于Negative Sampling的模型