Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer
阿新 • • 發佈:2020-08-15
# Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer
[ToC]
## 0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中 MultiStringIndexer 的實現。
因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。
本文緣由是想分析GBDT,發現GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。
## 0x01 概念
Alink的官方介紹是:MultiStringIndexer訓練元件的作用是訓練一個模型用於將多列字串對映為整數。
具體來說,StringIndexer(字串-索引變換)將標籤的"字串列"編碼為"標籤索引的列"。
- 標籤索引序列的取值範圍是[0,numLabels(字串中所有出現的單詞去掉重複的詞後的總和)],按照標籤出現頻率排序,出現最多的標籤索引為0(具體為升序降序是可以配置的)。
- 如果輸入是數值型,我們先將數值對映到字串,再對字串進行索引化。
- 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化後的標籤序列,則需要將這個pipeline的輸入列名字指定為索引化序列的名字。大部分情況下,通過setSelectedCols設定輸入的列名。
以這些輸入為例:
```java
("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")
```
對於第一列,MultiStringIndexer 對資料集的label進行重新編號。按label出現的頻次,轉換成0 ~ numOfLabels - 1(分類個數)。如果是按照從高到低排序,則頻次最高的轉換為0,以此類推,比如:
- football,出現次數最多,出現了3次,轉換(編號)為0
- 其次是basketball,出現了2次,編號為1,以此類推。
在應用StringIndexer對labels進行重新編號後,帶著這些編號後的label對資料進行了訓練,並接著對其他資料進行了預測,得到預測結果,預測結果的label也是重新編號過的,因此需要轉換回來。
## 0x02 示例程式碼
示例程式碼如下,本示例程式碼中,是按照升序排列,即football總數為3,則其idx為3,tennis個數為1,其idx為0:
```java
public class MultiStringIndexerExample {
static AlgoOperator getData(boolean isBatch) {
Row[] array = new Row[] {
Row.of("football", "can"),
Row.of("football", "hhh"),
Row.of("football", "zzz"),
Row.of("basketball", "zzz"),
Row.of("basketball", "can"),
Row.of("tennis", "can")
};
if (isBatch) {
return new MemSourceBatchOp(
Arrays.asList(array), new String[] {"a", "b"});
} else {
return new MemSourceStreamOp(
Arrays.asList(array), new String[] {"a", "b"});
}
}
public static void main(String[] args) throws Exception {
BatchOperator data = (BatchOperator)getData(true);
MultiStringIndexer stringindexer = new MultiStringIndexer()
.setSelectedCols("a", "b")
.setOutputCols("a_indexed", "b_indexed")
.setStringOrderType("frequency_asc");
stringindexer.fit(data).transform(data).print();
}
}
```
輸出如下:
```java
a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2
```
轉換成表格看的更清楚。
| a | b | a_indexed | b_indexed |
| ---------- | ---- | --------- | --------- |
| football | can | 2 | 2 |
| football | hhh | 2 | 0 |
| football | zzz | 2 | 1 |
| basketball | zzz | 1 | 1 |
| basketball | can | 1 | 2 |
| tennis | can | 0 | 2 |
## 0x03 總體邏輯
我們先給出一個流程圖
![](https://img2020.cnblogs.com/blog/1850883/202008/1850883-20200815074945552-187376175.png)
老套路,我們從 MultiStringIndexerTrainBatchOp.linkFrom開始挖掘。
```java
@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator... inputs) {
BatchOperator in = checkAndGetFirst(inputs);
// 示例中有 .setSelectedCols("a", "b"),這裡是取出具體列名字
final String[] selectedColNames = getSelectedCols();
// 獲取列的型別
final String[] selectedColSqlType = new String[selectedColNames.length];
for (int i = 0; i < selectedColNames.length; i++) {
selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
}
// runtime列印資料
selectedColNames = {String[2]@2536}
0 = "a"
1 = "b"
selectedColSqlType = {String[2]@2537}
0 = "VARCHAR"
1 = "VARCHAR"
// 獲取選取列對應的資料