1. 程式人生 > >Alink漫談(七) : 如何劃分訓練資料集和測試資料集

Alink漫談(七) : 如何劃分訓練資料集和測試資料集

# Alink漫談(七) : 如何劃分訓練資料集和測試資料集 [TOC] ## 0x00 摘要 Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將為大家展現Alink如何劃分訓練資料集和測試資料集。 ## 0x01 訓練資料集和測試資料集 **兩分法** 一般做預測分析時,會將資料分為兩大部分。一部分是訓練資料,用於構建模型,一部分是測試資料,用於檢驗模型。 **三分法** 但有時候模型的構建過程中也需要檢驗模型/輔助模型構建,這時會將訓練資料再分為兩個部分:1)訓練資料;2)驗證資料(Validation Data)。所以這種情況下會把資料分為三部分。 - 訓練資料(Train Data):用於模型構建。 - 驗證資料(Validation Data):可選,用於輔助模型構建,可以重複使用。 - 測試資料(Test Data):用於檢測模型構建,此資料只在模型檢驗時使用,用於評估模型的準確率。絕對不允許用於模型構建過程,否則會導致過渡擬合。 Training set是用來訓練模型或確定模型引數的,如ANN中權值等; Validation set是用來做模型選擇(model selection),即做模型的最終優化及確定,如ANN的結構; Test set則純粹是為了測試已經訓練好的模型的推廣能力。當然test set並不能保證模型的正確性,他只是說相似的資料用此模型會得出相似的結果。 **實際應用** 實際應用中,一般只將資料集分成兩類,即training set 和test set,大多數文章並不涉及validation set。我們這裡也不涉及。大家常用的sklearn的train_test_split函式就是將矩陣隨機劃分為訓練子集和測試子集,並返回劃分好的訓練集測試集樣本和訓練集測試集標籤。 ## 0x02 Alink示例程式碼 首先我們給出示例程式碼,然後會深入剖析: ```java public class SplitExample { public static void main(String[] args) throws Exception { String url = "iris.csv"; String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string"; //這裡是批處理 BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema); SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8); spliter.linkFrom(data); BatchOperator trainData = spliter; BatchOperator testData = spliter.getSideOutput(0); // 這裡是流處理 CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema); SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4); spliterS.linkFrom(dataS); StreamOperator train_data = spliterS; StreamOperator test_data = spliterS.getSideOutput(0); } } ``` ## 0x03 批處理 SplitBatchOp是分割批處理的主要類,具體構建DAG的工作是在其linkFrom完成的。 總體思路比較簡單: 1. 假定有一個取樣比例 fraction 2. 將資料集分割槽,平行計算每個分割槽上的記錄數 3. 把每個分割槽上的記錄數累積,得到所有記錄總數 totCount 4. 從上而下計算出一個取樣總數:`numTarget = totCount * fraction` 5. 因為具體選擇元素是在每個分割槽上做的,所以在每個分割槽上,分別計算出來這個分割槽應該取樣的記錄數,比如第n個分割槽上應取樣記錄數:`task_n_count * fraction` 6. 把這些分割槽 "應該取樣的記錄數" 累積,得出來從下而上計算出的取樣總數: `totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction` 7. numTarget 和 totSelect 可能不相等,所以隨機決定把多出來的 `numTarget - totSelect` 加入到某一個task中。 8. 在每個task上取樣得到具體的記錄。 ### 3.1 得到記錄數 如果要分割資料,首先必須知道資料集的記錄數。比如這個DataSet的記錄是1萬個?還是十萬個?因為資料集可能會很大,所以這一步操作也使用了並行處理,即把資料分割槽,然後通過mapPartition操作得到每一個分割槽上元素的數目。 ```java