1. 程式人生 > >Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer

Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer

# Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer [ToC] ## 0x00 摘要 Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中 MultiStringIndexer 的實現。 因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。 本文緣由是想分析GBDT,發現GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。 ## 0x01 概念 Alink的官方介紹是:MultiStringIndexer訓練元件的作用是訓練一個模型用於將多列字串對映為整數。 具體來說,StringIndexer(字串-索引變換)將標籤的"字串列"編碼為"標籤索引的列"。 - 標籤索引序列的取值範圍是[0,numLabels(字串中所有出現的單詞去掉重複的詞後的總和)],按照標籤出現頻率排序,出現最多的標籤索引為0(具體為升序降序是可以配置的)。 - 如果輸入是數值型,我們先將數值對映到字串,再對字串進行索引化。 - 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化後的標籤序列,則需要將這個pipeline的輸入列名字指定為索引化序列的名字。大部分情況下,通過setSelectedCols設定輸入的列名。 以這些輸入為例: ```java ("football", "can"), ("football", "hhh"), ("football", "zzz"), ("basketball", "zzz"), ("basketball", "can"), ("tennis", "can") ``` 對於第一列,MultiStringIndexer 對資料集的label進行重新編號。按label出現的頻次,轉換成0 ~ numOfLabels - 1(分類個數)。如果是按照從高到低排序,則頻次最高的轉換為0,以此類推,比如: - football,出現次數最多,出現了3次,轉換(編號)為0 - 其次是basketball,出現了2次,編號為1,以此類推。 在應用StringIndexer對labels進行重新編號後,帶著這些編號後的label對資料進行了訓練,並接著對其他資料進行了預測,得到預測結果,預測結果的label也是重新編號過的,因此需要轉換回來。 ## 0x02 示例程式碼 示例程式碼如下,本示例程式碼中,是按照升序排列,即football總數為3,則其idx為3,tennis個數為1,其idx為0: ```java public class MultiStringIndexerExample { static AlgoOperator getData(boolean isBatch) { Row[] array = new Row[] { Row.of("football", "can"), Row.of("football", "hhh"), Row.of("football", "zzz"), Row.of("basketball", "zzz"), Row.of("basketball", "can"), Row.of("tennis", "can") }; if (isBatch) { return new MemSourceBatchOp( Arrays.asList(array), new String[] {"a", "b"}); } else { return new MemSourceStreamOp( Arrays.asList(array), new String[] {"a", "b"}); } } public static void main(String[] args) throws Exception { BatchOperator data = (BatchOperator)getData(true); MultiStringIndexer stringindexer = new MultiStringIndexer() .setSelectedCols("a", "b") .setOutputCols("a_indexed", "b_indexed") .setStringOrderType("frequency_asc"); stringindexer.fit(data).transform(data).print(); } } ``` 輸出如下: ```java a|b|a_indexed|b_indexed -|-|---------|--------- football|can|2|2 football|hhh|2|0 football|zzz|2|1 basketball|zzz|1|1 basketball|can|1|2 tennis|can|0|2 ``` 轉換成表格看的更清楚。 | a | b | a_indexed | b_indexed | | ---------- | ---- | --------- | --------- | | football | can | 2 | 2 | | football | hhh | 2 | 0 | | football | zzz | 2 | 1 | | basketball | zzz | 1 | 1 | | basketball | can | 1 | 2 | | tennis | can | 0 | 2 | ## 0x03 總體邏輯 我們先給出一個流程圖 ![](https://img2020.cnblogs.com/blog/1850883/202008/1850883-20200815074945552-187376175.png) 老套路,我們從 MultiStringIndexerTrainBatchOp.linkFrom開始挖掘。 ```java @Override public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator... inputs) { BatchOperator in = checkAndGetFirst(inputs); // 示例中有 .setSelectedCols("a", "b"),這裡是取出具體列名字 final String[] selectedColNames = getSelectedCols(); // 獲取列的型別 final String[] selectedColSqlType = new String[selectedColNames.length]; for (int i = 0; i < selectedColNames.length; i++) { selectedColSqlType[i] = FlinkTypeConverter.getTypeString( TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i])); } // runtime列印資料 selectedColNames = {String[2]@2536} 0 = "a" 1 = "b" selectedColSqlType = {String[2]@2537} 0 = "VARCHAR" 1 = "VARCHAR" // 獲取選取列對應的資料