1. 程式人生 > >JAVA實現K-means聚類

JAVA實現K-means聚類

個人部落格站已經上線了,網址 www.llwjy.com ~歡迎各位吐槽~

-------------------------------------------------------------------------------------------------

      在開始之前先打一個小小的廣告,自己建立一個QQ群:321903218,點選連結加入群【Lucene案例開發】,主要用於交流如何使用Lucene來建立站內搜尋後臺,同時還會不定期的在群內開相關的公開課,感興趣的童鞋可以加入交流。

      在上一篇部落格中已經介紹了KNN分類演算法,這篇部落格將重點介紹下K-means聚類演算法。K-means演算法是比較經典的聚類演算法,演算法的基本思想是選取K個點(隨機)作為中心進行聚類,然後對聚類的結果計算該類的質心,通過迭代的方法不斷更新質心,直到質心不變或稍微移動為止,則最後的聚類結果就是最後的聚類結果。下面首先介紹下K-means具體的演算法步驟。

K-means演算法

      在前面已經大概的介紹了下K-means,下面就介紹下具體的演算法描述:

1)選取K個點作為初始質心;

2)對每個樣本分別計算到K個質心的相似度或距離,將該樣本劃分到相似度最高或距離最短的質心所在類;

3)對該輪聚類結果,計算每一個類別的質心,新的質心作為下一輪的質心;

4)判斷演算法是否滿足終止條件,滿足終止條件結束,否則繼續第2、3、4步。

      在介紹演算法之前,我們首先看下K-means演算法聚類平面200,000個點聚成34個類別的結果(如下圖)

img

演算法實現

      K-means聚類演算法整體思想比較簡單,下面 就分步介紹如何用JAVA來實現K-means演算法。

一、K-means演算法基礎屬性

      在K-means演算法中,有幾個重要的指標,比如K值、最大迭代次數等,對於這些指標,我們統一把它們設定為類的屬性,如下:

private List<T> dataArray;//待分類的原始值
private int K = 3;//將要分成的類別個數
private int maxClusterTimes = 500;//最大迭代次數
private List<List<T>> clusterList;//聚類的結果
private List<T> clusteringCenterT;//質心
二、初始質心的選擇

      K-means聚類演算法的結果很大程度收到初始質心的選取,這了為了保證有充分的隨機性,對於初始質心的選擇這裡採用完全隨機的方法,先把待分類的資料隨機打亂,然後把前K個樣本作為初始質心(通過多次迭代,會減少初始質心的影響)。

List<T> centerT = new ArrayList<T>(size);
//對資料進行打亂
Collections.shuffle(dataArray);
for (int i = 0; i < size; i++) {
	centerT.add(dataArray.get(i));
}
三、一輪聚類

      在K-means演算法中,大部分的時間都在做一輪一輪的聚類,具體功能也很簡單,就是對每一個樣本分別計算和所有質心的相似度或距離,找到與該樣本最相似的質心或者距離最近的質心,然後把該樣本劃分到該類中,具體邏輯介紹參照程式碼中的註釋。

private void clustering(List<T> preCenter, int times) {
	if (preCenter == null || preCenter.size() < 2) {
		return;
	}
	//打亂質心的順序
	Collections.shuffle(preCenter);
	List<List<T>> clusterList =  getListT(preCenter.size());
	for (T o1 : this.dataArray) {
		//尋找最相似的質心
		int max = 0;
		double maxScore = similarScore(o1, preCenter.get(0));
		for (int i = 1; i < preCenter.size(); i++) {
			if (maxScore < similarScore(o1, preCenter.get(i))) {
				maxScore = similarScore(o1, preCenter.get(i));
				max = i;
			}
		}
		clusterList.get(max).add(o1);
	}
	//計算本次聚類結果每個類別的質心
	List<T> nowCenter = new ArrayList<T> ();
	for (List<T> list : clusterList) {
		nowCenter.add(getCenterT(list));
	}
	//是否達到最大迭代次數
	if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
		this.clusterList = clusterList;
		return;
	}
	this.clusteringCenterT = nowCenter;
	//判斷質心是否發生移動,如果沒有移動,結束本次聚類,否則進行下一輪
	if (isCenterChange(preCenter, nowCenter)) {
		clear(clusterList);
		clustering(nowCenter, times + 1);
	} else {
		this.clusterList = clusterList;
	}
}
四、質心是否移動

      在第三步中,提到了一個重要的步驟:每輪聚類結束後,都要重新計算質心,並且計算質心是否發生移動。對於新質心的計算、樣本之間的相似度和判斷兩個樣本是否相等這幾個功能由於並不知道樣本的具體資料型別,因此把他們定義成抽象方法,供子類來實現。下面就重點介紹如何判斷質心是否發生移動。

private boolean isCenterChange(List<T> preT, List<T> nowT) {
	if (preT == null || nowT == null) {
		return false;
	}
	for (T t1 : preT) {
		boolean bol = true;
		for (T t2 : nowT) {
			if (equals(t1, t2)) {//t1在t2中有相等的,認為該質心未移動
				bol = false;
				break;
			}
		}
		//有一個質心發生移動,認為需要進行下一次計算
		if (bol) {
			return bol;
		}
	}
	return false;
}

      從上述程式碼可以看到,演算法的思想就是對於前後兩個質心陣列分別前一組的質心是否在後一個質心組中出現,有一個沒有出現,就認為質心發生了變動。

完整程式碼

      上面四步已經完整的介紹了K-means演算法的具體演算法思想,下面就看下完整的程式碼實現。

 /**  
 *@Description:  K-means聚類
 */ 
package com.lulei.datamining.knn;  

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
  
public abstract class KMeansClustering <T>{
	private List<T> dataArray;//待分類的原始值
	private int K = 3;//將要分成的類別個數
	private int maxClusterTimes = 500;//最大迭代次數
	private List<List<T>> clusterList;//聚類的結果
	private List<T> clusteringCenterT;//質心
	
	public int getK() {
		return K;
	}
	public void setK(int K) {
		if (K < 1) {
			throw new IllegalArgumentException("K must greater than 0");
		}
		this.K = K;
	}
	public int getMaxClusterTimes() {
		return maxClusterTimes;
	}
	public void setMaxClusterTimes(int maxClusterTimes) {
		if (maxClusterTimes < 10) {
			throw new IllegalArgumentException("maxClusterTimes must greater than 10");
		}
		this.maxClusterTimes = maxClusterTimes;
	}
	public List<T> getClusteringCenterT() {
		return clusteringCenterT;
	}
	/**
	 * @return
	 * @Author:lulei  
	 * @Description: 對資料進行聚類
	 */
	public List<List<T>> clustering() {
		if (dataArray == null) {
			return null;
		}
		//初始K個點為陣列中的前K個點
		int size = K > dataArray.size() ? dataArray.size() : K;
		List<T> centerT = new ArrayList<T>(size);
		//對資料進行打亂
		Collections.shuffle(dataArray);
		for (int i = 0; i < size; i++) {
			centerT.add(dataArray.get(i));
		}
		clustering(centerT, 0);
		return clusterList;
	}
	
	/**
	 * @param preCenter
	 * @param times
	 * @Author:lulei  
	 * @Description: 一輪聚類
	 */
	private void clustering(List<T> preCenter, int times) {
		if (preCenter == null || preCenter.size() < 2) {
			return;
		}
		//打亂質心的順序
		Collections.shuffle(preCenter);
		List<List<T>> clusterList =  getListT(preCenter.size());
		for (T o1 : this.dataArray) {
			//尋找最相似的質心
			int max = 0;
			double maxScore = similarScore(o1, preCenter.get(0));
			for (int i = 1; i < preCenter.size(); i++) {
				if (maxScore < similarScore(o1, preCenter.get(i))) {
					maxScore = similarScore(o1, preCenter.get(i));
					max = i;
				}
			}
			clusterList.get(max).add(o1);
		}
		//計算本次聚類結果每個類別的質心
		List<T> nowCenter = new ArrayList<T> ();
		for (List<T> list : clusterList) {
			nowCenter.add(getCenterT(list));
		}
		//是否達到最大迭代次數
		if (times >= this.maxClusterTimes || preCenter.size() < this.K) {
			this.clusterList = clusterList;
			return;
		}
		this.clusteringCenterT = nowCenter;
		//判斷質心是否發生移動,如果沒有移動,結束本次聚類,否則進行下一輪
		if (isCenterChange(preCenter, nowCenter)) {
			clear(clusterList);
			clustering(nowCenter, times + 1);
		} else {
			this.clusterList = clusterList;
		}
	}
	
	/**
	 * @param size
	 * @return
	 * @Author:lulei  
	 * @Description: 初始化一個聚類結果
	 */
	private List<List<T>> getListT(int size) {
		List<List<T>> list = new ArrayList<List<T>>(size);
		for (int i = 0; i < size; i++) {
			list.add(new ArrayList<T>());
		}
		return list;
	}
	
	/**
	 * @param lists
	 * @Author:lulei  
	 * @Description: 清空無用陣列
	 */
	private void clear(List<List<T>> lists) {
		for (List<T> list : lists) {
			list.clear();
		}
		lists.clear();
	}
	
	/**
	 * @param value
	 * @Author:lulei  
	 * @Description: 向模型中新增記錄
	 */
	public void addRecord(T value) {
		if (dataArray == null) {
			dataArray = new ArrayList<T>();
		}
		dataArray.add(value);
	}
	
	/**
	 * @param preT
	 * @param nowT
	 * @return
	 * @Author:lulei  
	 * @Description: 判斷質心是否發生移動
	 */
	private boolean isCenterChange(List<T> preT, List<T> nowT) {
		if (preT == null || nowT == null) {
			return false;
		}
		for (T t1 : preT) {
			boolean bol = true;
			for (T t2 : nowT) {
				if (equals(t1, t2)) {//t1在t2中有相等的,認為該質心未移動
					bol = false;
					break;
				}
			}
			//有一個質心發生移動,認為需要進行下一次計算
			if (bol) {
				return bol;
			}
		}
		return false;
	}
	
	/**
	 * @param o1
	 * @param o2
	 * @return
	 * @Author:lulei  
	 * @Description: o1 o2之間的相似度
	 */
	public abstract double similarScore(T o1, T o2);
	
	/**
	 * @param o1
	 * @param o2
	 * @return
	 * @Author:lulei  
	 * @Description: 判斷o1 o2是否相等
	 */
	public abstract boolean equals(T o1, T o2);
	
	/**
	 * @param list
	 * @return
	 * @Author:lulei  
	 * @Description: 求一組資料的質心
	 */
	public abstract T getCenterT(List<T> list);
}

二維數聚類實現

      在演算法描述中,介紹了一個200,000個點聚成34個類別的效果圖,下面就針對二維座標資料實現其具體子類。

一、相似度

      對於二維座標的相似度,這裡我們採取兩點間聚類的相反數,具體實現如下:

	@Override
	public double similarScore(XYbean o1, XYbean o2) {
		double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));
		return distance * -1;
	}
二、樣本/質心是否相等

      判斷樣本/質心是否相等只需要判斷兩點的座標是否相等即可,具體實現如下:

	@Override
	public boolean equals(XYbean o1, XYbean o2) {
		return o1.getX() == o2.getX() && o1.getY() == o2.getY();
	}
三、獲取一個分類下的新質心

      對於二維座標資料,可以使用所有點的重心作為分類的質心,具體如下:

	@Override
	public XYbean getCenterT(List<XYbean> list) {
		int x = 0;
		int y = 0;
		try {
			for (XYbean xy : list) {
				x += xy.getX();
				y += xy.getY();
			}
			x = x / list.size();
			y = y / list.size();
		} catch(Exception e) {
			
		}
		return new XYbean(x, y);
	}
四、main方法

      對於具體二維座標的原始碼這裡就不再貼出來,就是實現前面介紹的抽象類,並實現其中的3個抽象方法,下面我們就隨機產生200,000個點,然後聚成34個類別,具體程式碼如下:

	public static void main(String[] args) {
		
		int width = 600;
		int height = 400;
		int K = 34;
		XYCluster xyCluster = new XYCluster();
		for (int i = 0; i < 200000; i++) {
			int x = (int)(Math.random() * width) + 1;
			int y = (int)(Math.random() * height) + 1;
			xyCluster.addRecord(new XYbean(x, y));
		}
		xyCluster.setK(K);
		long a = System.currentTimeMillis();
		List<List<XYbean>> cresult = xyCluster.clustering();
		List<XYbean> center = xyCluster.getClusteringCenterT();
		System.out.println(JsonUtil.parseJson(center));
		long b = System.currentTimeMillis();
		System.out.println("耗時:" + (b - a) + "ms");
		new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);
	}
      對於這隨機產生的200,000個點聚成34類,總耗時5485ms。(計算機配置:i5 + 8G記憶體)

-------------------------------------------------------------------------------------------------
小福利
-------------------------------------------------------------------------------------------------
      個人在極客學院上《Lucene案例開發》課程已經上線了,歡迎大家吐槽~