Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer
Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer
0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 MultiStringIndexer 的实现。
因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。
本文缘由是想分析GBDT,发现GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。
0x01 概念
Alink的官方介绍是:MultiStringIndexer训练组件的作用是训练一个模型用于将多列字符串映射为整数。
具体来说,StringIndexer(字符串-索引变换)将标签的"字符串列"编码为"标签索引的列"。
- 标签索引序列的取值范围是[0,numLabels(字符串中所有出现的单词去掉重复的词后的总和)],按照标签出现频率排序,出现最多的标签索引为0(具体为升序降序是可以配置的)。
- 如果输入是数值型,我们先将数值映射到字符串,再对字符串进行索引化。
- 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化后的标签序列,则需要将这个pipeline的输入列名字指定为索引化序列的名字。大部分情况下,通过setSelectedCols设置输入的列名。
以这些输入为例:
("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")
对于第一列,MultiStringIndexer 对数据集的label进行重新编号。按label出现的频次,转换成0 ~ numOfLabels - 1(分类个数)。如果是按照从高到低排序,则频次最高的转换为0,以此类推,比如:
- football,出现次数最多,出现了3次,转换(编号)为0
- 其次是basketball,出现了2次,编号为1,以此类推。
在应用StringIndexer对labels进行重新编号后,带着这些编号后的label对数据进行了训练,并接着对其他数据进行了预测,得到预测结果,预测结果的label也是重新编号过的,因此需要转换回来。
0x02 示例代码
示例代码如下,本示例代码中,是按照升序排列,即football总数为3,则其idx为3,tennis个数为1,其idx为0:
public class MultiStringIndexerExample {
static AlgoOperator getData(boolean isBatch) {
Row[] array = new Row[] {
Row.of("football", "can"),
Row.of("football", "hhh"),
Row.of("football", "zzz"),
Row.of("basketball", "zzz"),
Row.of("basketball", "can"),
Row.of("tennis", "can")
};
if (isBatch) {
return new MemSourceBatchOp(
Arrays.asList(array), new String[] {"a", "b"});
} else {
return new MemSourceStreamOp(
Arrays.asList(array), new String[] {"a", "b"});
}
}
public static void main(String[] args) throws Exception {
BatchOperator data = (BatchOperator)getData(true);
MultiStringIndexer stringindexer = new MultiStringIndexer()
.setSelectedCols("a", "b")
.setOutputCols("a_indexed", "b_indexed")
.setStringOrderType("frequency_asc");
stringindexer.fit(data).transform(data).print();
}
}
输出如下:
a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2
转换成表格看的更清楚。
a | b | a_indexed | b_indexed |
---|---|---|---|
football | can | 2 | 2 |
football | hhh | 2 | 0 |
football | zzz | 2 | 1 |
basketball | zzz | 1 | 1 |
basketball | can | 1 | 2 |
tennis | can | 0 | 2 |
0x03 总体逻辑
我们先给出一个流程图
老套路,我们从 MultiStringIndexerTrainBatchOp.linkFrom开始挖掘。
@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
// 示例中有 .setSelectedCols("a", "b"),这里是取出具体列名字
final String[] selectedColNames = getSelectedCols();
// 获取列的类型
final String[] selectedColSqlType = new String[selectedColNames.length];
for (int i = 0; i < selectedColNames.length; i++) {
selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
}
// runtime打印数据
selectedColNames = {String[2]@2536}
0 = "a"
1 = "b"
selectedColSqlType = {String[2]@2537}
0 = "VARCHAR"
1 = "VARCHAR"
// 获取选取列对应的数据
DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
//
DataSet<Tuple3<Integer, String, Long>> indexedToken =
StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true);
DataSet<Row> values = indexedToken
.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
@Override
public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
throws Exception {
Params meta = null;
if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
// 第一个task会做这个计算,就是把列名,列类型作为元数据传送
meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
.set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
}
// runtime打印数据
meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
params = {HashMap@9316} size = 2
new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
}
})
.name("build_model");
this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
return this;
}
训练过程总体逻辑总结如下:
- 取出具体列名字,列的类型;
- 获取"选取列"对应的数据;
- 把列名,列类型作为元数据传送;
StringIndexerUtil.indexTokens
给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关;- 调用到
indexSortedByFreq(data, startIndex, ignoreNull, true)
,作用是给各个列的不同字串赋予连续的indices,indices是按照字符串出现的频率排序;- 调用到 countTokens的作用是按照 "列idx","word" 来合并计算单词个数,得到<"列idx","word",单词个数>,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。
- 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
- 对上面结构进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ;
- 按照 "列idx","word" 来分组;
- 按照 "列idx","word" 来合并计算单词个数;
- indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;
- 首先按照 列idx 做分组;
- 然后在上面结果基础上,按照单词个数排序;
- 排序的index是以输入参数startIndex开始,startIndex在这里是0;
- 最后得到 第一列的 (0,football,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
- 调用到 countTokens的作用是按照 "列idx","word" 来合并计算单词个数,得到<"列idx","word",单词个数>,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。
- 调用到
- 把indexTokens的结果存储为模型,其中使用之前提到的 "把列名,列类型作为元数据"。
下面具体剖析后两个阶段。
0x04 Add Index to Token
这部分就是给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关。
具体是由StringIndexerUtil.indexTokens 做到的。
public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
final long startIndex, final boolean ignoreNull) {
case FREQUENCY_ASC:
return indexSortedByFreq(data, startIndex, ignoreNull, true);
}
4.1 合并计算单词个数
indexSortedByFreq会调用countTokens来计算单词个数,所以我们先看countTokens。
countTokens的作用是按照 "列idx","word" 来合并计算单词个数,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。
具体逻辑如下:
- 调用 flattenTokens 把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如对于 Row.of("football", "can") 这个输入,flattenTokens 输出两个Tuple2 ,<0, "football"> 和 <1, "can">。
- 对上面结果进行map操作,输出<column idx, word, 1L>,比如 <0, "football", 1L> ,这个是计数的常规操作。
- 按照 "列idx","word" 来分组;
- 按照 "列idx","word" 来合并计算单词个数,就是不停归并上面的 1L。
4.1.1 打散输入数据
其中 flattenTokens 的作用是把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token.。
比如对于 Row.of("football", "can") 这个输入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o)));
输出两个Tuple2。
value = {Row@9212} "football,can"
fields = {Object[2]@9215}
0 = "football"
1 = "can"
输出 <0, "football"> 和 <1, "can">
4.1.2 分组计算个数
这是通过flattenTokens的结果进行 map,groupBy,reduce的一系列操作完成的。
具体代码如下:
public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
return flattenTokens(data, ignoreNull) // 把输入数据 Row 给打散
.map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
return Tuple3.of(value.f0, value.f1, 1L); // 输出<column idx, word, 1L>,比如 <0, "football", 1L>
}
})
.groupBy(0, 1) // 按照 "列idx","word" 来分组
.reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
value1.f2 += value2.f2;
return value1; // 按照 "列idx","word" 来合并计算单词个数
}
})
.name("count_tokens");
}
// reduce之后发出
value1 = {Tuple3@9284} "(0,football,3)"
f0 = {Integer@9226} 0
f1 = "football"
f2 = {Long@9295} 3
4.2 合并计算单词个数
前面 countTokens的 返回三元组是 <列idx","word" ,词频>,其中列的idx从0开始计算。
indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;
- 首先按照 列idx 做分组;
- 然后在上面结果基础上,按照单词个数排序;
- 排序的index是以输入参数startIndex开始,startIndex在这里是0;
- 最后得到 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
具体代码如下:
public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
return countTokens(data, ignoreNull)
.groupBy(0) //按照 列idx 做分组
.sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照单词个数排序
.reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
@Override
public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
Collector<Tuple3<Integer, String, Long>> out) {
long id = startIndex;
for (Tuple3<Integer, String, Long> value : values) {
out.collect(Tuple3.of(value.f0, value.f1, id++)); // 归并
}
}
});
}
0x05 输出模型
这部分分为两部分:
- 输出元数据,就是之前得到的 "把列名,列类型作为元数据"。
- 输出具体每一列的每一个单词信息,比如 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的数据 (1,hhh,0),(1,zzz,1),(1,can,2);
public class MultiStringIndexerModelDataConverter implements
ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
@Override
public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
if (modelData.f0 != null) {
collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
}
modelData.f1.forEach(tuple -> {
collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
});
}
}
tuple = {Tuple3@9405} "(0,tennis,0)"
f0 = {Integer@9406} 0
f1 = "tennis"
f2 = {Long@9408} 0
0x06 预测
预测功能是在 ModelMapperAdapter 完成的。
public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
private final ModelMapper mapper;
private final ModelSource modelSource;
@Override
public void open(Configuration parameters) throws Exception {
List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
this.mapper.loadModel(modelRows); //加载模型
}
@Override
public Row map(Row row) throws Exception {
return this.mapper.map(row); //预测
}
}
6.1 加载模型
MultiStringIndexerModelDataConverter中我们会进行模型加载。
- 首先会加载元信息
- 其次会逐条加载模型信息
public MultiStringIndexerModelData load(List<Row> rows) {
MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
modelData.tokenAndIndex = new ArrayList<>();
modelData.tokenNumber = new HashMap<>();
for (Row row : rows) {
long colIndex = (Long) row.getField(0);
if (colIndex < 0L) { // 元数据
modelData.meta = Params.fromJson((String) row.getField(1));
} else { // 具体模型信息
int columnIndex = ((Long) row.getField(0)).intValue();
Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合并列数据个数
}
}
// To ensure that every columns has token number.
int numFields = 0;
if (modelData.meta != null) {
numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
}
for (int i = 0; i < numFields; i++) {
modelData.tokenNumber.merge(i, 0L, Long::sum);
}
return modelData;
}
最后模型内容如下,其中 tokenNumber 表示每列的数据有几个,tokenAndIndex表示具体信息,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他们都是第一列的,basketball转换后的数据是 1:
modelData = {MultiStringIndexerModelData@9348}
meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
tokenAndIndex = {ArrayList@9360} size = 6
0 = {Tuple3@9472} "(0,football,2)"
1 = {Tuple3@9511} "(0,tennis,0)"
2 = {Tuple3@9512} "(1,zzz,1)"
3 = {Tuple3@9513} "(1,hhh,0)"
4 = {Tuple3@9514} "(0,basketball,1)"
5 = {Tuple3@9515} "(1,can,2)"
tokenNumber = {HashMap@9385} size = 2
{Integer@9507} 0 -> {Long@9508} 3
{Integer@9509} 1 -> {Long@9508} 3
numFields = 2
6.2 预测
预测是在 MultiStringIndexerModelMapper 完成的。
// 假设输入是:row = {Row@9309} "football,can"
// 选择的列是:selectedColNames = {String[2]@9314} 0 = "a" 1 = "b"
// 模型映射器是:
this = {MultiStringIndexerModelMapper@9309}
indexMapper = {HashMap@9318} size = 2
{Integer@9357} 0 -> {HashMap@9314} size = 3
key = {Integer@9357} 0
value = 0
value = {HashMap@9314} size = 3
"basketball" -> {Long@9386} 1
"football" -> {Long@9332} 2
"tennis" -> {Long@9384} 0
{Integer@9352} 1 -> {HashMap@9358} size = 3
key = {Integer@9352} 1
value = 1
value = {HashMap@9358} size = 3
"can" -> {Long@9332} 2
"hhh" -> {Long@9384} 0
"zzz" -> {Long@9386} 1
则经历过下列代码,最后就可以进行预测
public Row map(Row row) throws Exception {
Row result = new Row(selectedColNames.length);
for (int i = 0; i < selectedColNames.length; i++) {
Map<String, Long> mapper = indexMapper.get(i);
int colIdxInData = selectedColIndicesInData[i];
Object val = row.getField(colIdxInData);
String key = val == null ? null : String.valueOf(val);
Long index = mapper.get(key);
if (index != null) {
result.setField(i, index); // 我们主要执行在这里
} else {
}
}
// 最后预测结果是:
row = {Row@9308} "football,can"
result = {Row@9313} "2,2"
return outputColsHelper.getResultRow(row, result);
}