在Ignite中使用k-均值聚類演算法
在本系列前面的文章中,簡單介紹了一下Ignite的k-最近鄰(k-NN)分類演算法,下面會嘗試另一個機器學習演算法,即使用泰坦尼克資料集介紹k-均值聚類演算法。正好,Kaggle提供了CSV格式的資料集,而要分析的是兩個分類:即乘客是否倖存。
為了將資料轉換為Ignite支援的格式,前期需要做一些清理和格式化的工作,CSV檔案中包含若干個列,如下:
- 乘客Id
- 倖存(0:否,1:是)
- 船票席別(1:一,2:二,3:三)
- 乘客姓名
- 性別
- 年齡
- 泰坦尼克號上的兄弟/姐妹數
- 泰坦尼克號上的父母/子女數
- 船票號碼
- 票價
- 客艙號碼
- 登船港口(C=瑟堡,Q=皇后鎮,S=南安普頓)
因此首先要做的是,刪除任何和特定乘客有關的、和生存無關的列,如下:
- 乘客Id
- 乘客姓名
- 船票號碼
- 客艙號碼
接下來會刪除任何資料有缺失的行,比如年齡或者登船港口,可以對這些值進行歸類,但是為了進行初步的分析,會刪除缺失的值。
最後會將部分欄位轉換為數值型別,比如性別會被轉換為:
- 0:女
- 1:男
登船港口會被轉換為:
- 0:Q(皇后鎮)
- 1:C(瑟堡)
- 2:S(南安普頓)
最終的資料集由如下的列組成:
- 船票席別
- 性別
- 年齡
- 泰坦尼克號上的兄弟/姐妹數
- 泰坦尼克號上的父母/子女數
- 票價
- 登船港口
- 倖存
可以看到,倖存列已被移到最後。
下一步會將資料拆分為訓練資料(80%)和測試資料(20%),和前文一樣,還是使用Scikit-learn來執行這個拆分任務。
準備好訓練和測試資料後,就可以編寫應用了,本文的演算法是:
- 讀取訓練資料和測試資料;
- 在Ignite中儲存訓練資料和測試資料;
- 使用訓練資料擬合k-均值聚類模型;
- 將模型應用於測試資料;
- 確定含混矩陣和模型的準確性。
讀取訓練資料和測試資料
通過下面的程式碼,可以從CSV檔案中讀取資料:
private static void loadData(String fileName, IgniteCache<Integer, TitanicObservation> cache) throws FileNotFoundException { Scanner scanner = new Scanner(new File(fileName)); int cnt = 0; while (scanner.hasNextLine()) { String row = scanner.nextLine(); String[] cells = row.split(","); double[] features = new double[cells.length - 1]; for (int i = 0; i < cells.length - 1; i++) features[i] = Double.valueOf(cells[i]); double survivedClass = Double.valueOf(cells[cells.length - 1]); cache.put(cnt++, new TitanicObservation(features, survivedClass)); } }
該程式碼簡單地一行行的讀取資料,然後對於每一行,使用CSV的分隔符拆分出欄位,每個欄位之後將轉換成double型別並且存入Ignite。
將訓練資料和測試資料存入Ignite
前面的程式碼將資料存入Ignite,要使用這個程式碼,首先要建立Ignite儲存,如下:
IgniteCache<Integer, TitanicObservation> trainData = getCache(ignite, "TITANIC_TRAIN");
IgniteCache<Integer, TitanicObservation> testData = getCache(ignite, "TITANIC_TEST");
loadData("src/main/resources/titanic-train.csv", trainData);
loadData("src/main/resources/titanic-test.csv", testData);
getCache()
的實現如下:
private static IgniteCache<Integer, TitanicObservation> getCache(Ignite ignite, String cacheName) {
CacheConfiguration<Integer, TitanicObservation> cacheConfiguration = new CacheConfiguration<>();
cacheConfiguration.setName(cacheName);
cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
IgniteCache<Integer, TitanicObservation> cache = ignite.createCache(cacheConfiguration);
return cache;
}
使用訓練資料擬合k-NN分類模型
資料儲存之後,可以像下面這樣建立訓練器:
KMeansTrainer trainer = new KMeansTrainer()
.withK(2)
.withDistance(new EuclideanDistance())
.withSeed(123L);
這裡k的值配置為2,表示有2個簇(倖存和未倖存),對於距離測量,可以有多個選擇,比如歐幾里得、海明或曼哈頓,在本例中會使用歐幾里得,另外,種子值賦值為123。
然後擬合訓練資料,如下:
KMeansModel mdl = trainer.fit(
ignite,
trainData,
(k, v) -> v.getFeatures(),
// Feature extractor.
(k, v) -> v.getSurvivedClass()
// Label extractor.
);
Ignite將資料儲存為鍵-值(K-V)格式,因此上面的程式碼使用了值部分,目標值是Survived
類,特徵在其它列中。
將模型應用於測試資料
下一步,就可以用訓練好的分類模型測試測試資料了,可以這樣做:
int amountOfErrors = 0;
int totalAmount = 0;
int[][] confusionMtx = {{0, 0}, {0, 0}};
try (QueryCursor<Cache.Entry<Integer, TitanicObservation>> cursor = testData.query(new ScanQuery<>())) {
for (Cache.Entry<Integer, TitanicObservation> testEntry : cursor) {
TitanicObservation observation = testEntry.getValue();
double groundTruth = observation.getSurvivedClass();
double prediction = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));
totalAmount++;
if ((int) groundTruth != (int) prediction)
amountOfErrors++;
int idx1 = (int) prediction;
int idx2 = (int) groundTruth;
confusionMtx[idx1][idx2]++;
System.out.printf(">>> | %.4f\t | %.0f\t\t\t|\n", prediction, groundTruth);
}
}
確定含混矩陣和模型的準確性
下面,就可以通過對測試資料中的真實分類和模型進行的分類進行對比,來確認模型的真確性。
程式碼執行之後,輸出如下:
>>> Absolute amount of errors 56
>>> Accuracy 0.6084
>>> Precision 0.5865
>>> Recall 0.9873
>>> Confusion matrix is [[78, 55], [1, 9]]
這個初步的結果可不可以改進?可以嘗試的是對特徵的衡量,在Ignite和Scikit-learn中,可以使用MinMaxScaler()
,然後會給出如下的輸出:
>>> Absolute amount of errors 29
>>> Accuracy 0.7972
>>> Precision 0.8205
>>> Recall 0.8101
>>> Confusion matrix is [[64, 14], [15, 50]]
作為進一步分析的一部分,還應該研究倖存與否和年齡和性別之間的關係。
總結
通常來說,k-均值聚類並不適合監督學習任務,但是如果分類很容易,這個方法還是有效的。對於本例來說,關注的就是是否倖存。