1. 程式人生 > >在Ignite中使用k-均值聚類演算法

在Ignite中使用k-均值聚類演算法

在本系列前面的文章中,簡單介紹了一下Ignite的k-最近鄰(k-NN)分類演算法,下面會嘗試另一個機器學習演算法,即使用泰坦尼克資料集介紹k-均值聚類演算法。正好,Kaggle提供了CSV格式的資料集,而要分析的是兩個分類:即乘客是否倖存。

為了將資料轉換為Ignite支援的格式,前期需要做一些清理和格式化的工作,CSV檔案中包含若干個列,如下:

  • 乘客Id
  • 倖存(0:否,1:是)
  • 船票席別(1:一,2:二,3:三)
  • 乘客姓名
  • 性別
  • 年齡
  • 泰坦尼克號上的兄弟/姐妹數
  • 泰坦尼克號上的父母/子女數
  • 船票號碼
  • 票價
  • 客艙號碼
  • 登船港口(C=瑟堡,Q=皇后鎮,S=南安普頓)

因此首先要做的是,刪除任何和特定乘客有關的、和生存無關的列,如下:

  • 乘客Id
  • 乘客姓名
  • 船票號碼
  • 客艙號碼

接下來會刪除任何資料有缺失的行,比如年齡或者登船港口,可以對這些值進行歸類,但是為了進行初步的分析,會刪除缺失的值。

最後會將部分欄位轉換為數值型別,比如性別會被轉換為:

  • 0:女
  • 1:男

登船港口會被轉換為:

  • 0:Q(皇后鎮)
  • 1:C(瑟堡)
  • 2:S(南安普頓)

最終的資料集由如下的列組成:

  • 船票席別
  • 性別
  • 年齡
  • 泰坦尼克號上的兄弟/姐妹數
  • 泰坦尼克號上的父母/子女數
  • 票價
  • 登船港口
  • 倖存

可以看到,倖存列已被移到最後。

下一步會將資料拆分為訓練資料(80%)和測試資料(20%),和前文一樣,還是使用Scikit-learn來執行這個拆分任務。

準備好訓練和測試資料後,就可以編寫應用了,本文的演算法是:

  1. 讀取訓練資料和測試資料;
  2. 在Ignite中儲存訓練資料和測試資料;
  3. 使用訓練資料擬合k-均值聚類模型;
  4. 將模型應用於測試資料;
  5. 確定含混矩陣和模型的準確性。

讀取訓練資料和測試資料

通過下面的程式碼,可以從CSV檔案中讀取資料:

private static void loadData(String fileName, IgniteCache<Integer, TitanicObservation> cache)
        throws FileNotFoundException {

   Scanner scanner = new Scanner(new File(fileName));

   int cnt = 0;
   while (scanner.hasNextLine()) {
      String row = scanner.nextLine();
      String[] cells = row.split(",");
      double[] features = new double[cells.length - 1];

      for (int i = 0; i < cells.length - 1; i++)
         features[i] = Double.valueOf(cells[i]);
      double survivedClass = Double.valueOf(cells[cells.length - 1]);

      cache.put(cnt++, new TitanicObservation(features, survivedClass));
   }
}

該程式碼簡單地一行行的讀取資料,然後對於每一行,使用CSV的分隔符拆分出欄位,每個欄位之後將轉換成double型別並且存入Ignite。

將訓練資料和測試資料存入Ignite

前面的程式碼將資料存入Ignite,要使用這個程式碼,首先要建立Ignite儲存,如下:

IgniteCache<Integer, TitanicObservation> trainData = getCache(ignite, "TITANIC_TRAIN");
IgniteCache<Integer, TitanicObservation> testData = getCache(ignite, "TITANIC_TEST");
loadData("src/main/resources/titanic-train.csv", trainData);
loadData("src/main/resources/titanic-test.csv", testData);

getCache()的實現如下:

private static IgniteCache<Integer, TitanicObservation> getCache(Ignite ignite, String cacheName) {

   CacheConfiguration<Integer, TitanicObservation> cacheConfiguration = new CacheConfiguration<>();
   cacheConfiguration.setName(cacheName);
   cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

   IgniteCache<Integer, TitanicObservation> cache = ignite.createCache(cacheConfiguration);

   return cache;
}

使用訓練資料擬合k-NN分類模型

資料儲存之後,可以像下面這樣建立訓練器:

KMeansTrainer trainer = new KMeansTrainer()
        .withK(2)
        .withDistance(new EuclideanDistance())
        .withSeed(123L);

這裡k的值配置為2,表示有2個簇(倖存和未倖存),對於距離測量,可以有多個選擇,比如歐幾里得、海明或曼哈頓,在本例中會使用歐幾里得,另外,種子值賦值為123。

然後擬合訓練資料,如下:

KMeansModel mdl = trainer.fit(
        ignite,
        trainData,
        (k, v) -> v.getFeatures(),        
// Feature extractor.

        (k, v) -> v.getSurvivedClass()    
// Label extractor.

);

Ignite將資料儲存為鍵-值(K-V)格式,因此上面的程式碼使用了值部分,目標值是Survived類,特徵在其它列中。

將模型應用於測試資料

下一步,就可以用訓練好的分類模型測試測試資料了,可以這樣做:

int amountOfErrors = 0;
int totalAmount = 0;
int[][] confusionMtx = {{0, 0}, {0, 0}};

try (QueryCursor<Cache.Entry<Integer, TitanicObservation>> cursor = testData.query(new ScanQuery<>())) {
   for (Cache.Entry<Integer, TitanicObservation> testEntry : cursor) {
      TitanicObservation observation = testEntry.getValue();

      double groundTruth = observation.getSurvivedClass();
      double prediction = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));

      totalAmount++;
      if ((int) groundTruth != (int) prediction)
         amountOfErrors++;

      int idx1 = (int) prediction;
      int idx2 = (int) groundTruth;

      confusionMtx[idx1][idx2]++;

      System.out.printf(">>> | %.4f\t | %.0f\t\t\t|\n", prediction, groundTruth);
   }
}

確定含混矩陣和模型的準確性

下面,就可以通過對測試資料中的真實分類和模型進行的分類進行對比,來確認模型的真確性。

程式碼執行之後,輸出如下:

>>> Absolute amount of errors 56

>>> Accuracy 0.6084

>>> Precision 0.5865

>>> Recall 0.9873

>>> Confusion matrix is [[78, 55], [1, 9]]

這個初步的結果可不可以改進?可以嘗試的是對特徵的衡量,在Ignite和Scikit-learn中,可以使用MinMaxScaler(),然後會給出如下的輸出:

>>> Absolute amount of errors 29

>>> Accuracy 0.7972

>>> Precision 0.8205

>>> Recall 0.8101

>>> Confusion matrix is [[64, 14], [15, 50]]

作為進一步分析的一部分,還應該研究倖存與否和年齡和性別之間的關係。

總結

通常來說,k-均值聚類並不適合監督學習任務,但是如果分類很容易,這個方法還是有效的。對於本例來說,關注的就是是否倖存。