1. 程式人生 > >Kmeans演算法詳解及實現

Kmeans演算法詳解及實現

今天我們介紹資料探勘領域最基本的一個演算法,Kmeans演算法,並進行演算法的講解及實現。

我們知道聚類問題屬於經典問題,而對於聚類演算法,也是有很多不同的種類,kmeans就是其中一種最基本的聚類演算法。

它的主要演算法流程是:

(1)隨機的取k個點作為k個初始質心;

(2)計算其他點到這個k個質心的距離;

(3)如果某個點p離第n個質心的距離更近,則該點屬於cluster n,並對其打標籤,標註point p.label=n,其中n<=k;

(4)計算同一cluster中,也就是相同label的點向量的平均值,作為新的質心;

(5)迭代至所有質心都不變化為止,即演算法結束。

當然實現的方法有很多,比如在選擇初始質心時,可以隨機選擇k個,也可以隨機選擇k個離得最遠的點等等,方法不盡相同。

對於k值,必須提前知道,這也是kmeans演算法的一個缺點。當然對於k值,我們可有有很多種方法進行估計。本文中,我們採用平均直徑法來進行k的估計。

也就是說,首先視所有點為一個大的整體cluster,計算所有點間距離的平均值作為該cluster的平均直徑。選擇初始質心的時候,先選擇最遠的兩個點,接下來從這最兩個點開始,與這最兩個點距離都很遠的點(遠的程度為,該點到之前選擇的最遠的兩個點的距離都大於整體cluster的平均直徑)可視為新發現的質心,否則不視之為質心。

這樣,我們就能估計出k的值,並且得到k個初始質心,接著,我們便根據上述演算法流程繼續進行迭代,直到所有質心都不變化,從而成功實現演算法。

本文實現程式碼為最基礎的實現方式,如果資料多維,可能會需要做資料預處理,比如歸一化,並且修改程式碼相關函式即可。

下面附上程式碼,關鍵處已有註釋,如有問題請留言:

附上一組測試資料,執行前請將資料copy至c:\\kmeans.txt

下面資料的意義為點座標:

1,1
2,1
1,2
2,2
6,1
6,2
7,1
7,2
1,5
1,6
2,5
2,6
6,5
6,6
7,5
7,6

得到輸出結果為:

There are 4 clusters!
1.0 1.0 belongs to cluster 1
2.0 1.0 belongs to cluster 1
1.0 2.0 belongs to cluster 1
2.0 2.0 belongs to cluster 1
6.0 1.0 belongs to cluster 3
6.0 2.0 belongs to cluster 3
7.0 1.0 belongs to cluster 3
7.0 2.0 belongs to cluster 3
1.0 5.0 belongs to cluster 4
1.0 6.0 belongs to cluster 4
2.0 5.0 belongs to cluster 4
2.0 6.0 belongs to cluster 4
6.0 5.0 belongs to cluster 2
6.0 6.0 belongs to cluster 2
7.0 5.0 belongs to cluster 2
7.0 6.0 belongs to cluster 2

這裡附上適用於n維資料集的kmeans實現程式碼,其實很簡單,還是那句話,這個程式碼對不同的資料集效果可能不同,關鍵因素有很多,主要在於對k的估計,k個初始質心的選擇,以及資料預處理,節點間採用的距離度量方式等等。當然,kmeans本身就是最naive的方法,如果想得到更好的結果,可以採用其他的方法,比如神經網路等等。

注:此程式碼不包括資料預處理,可根據資料集特性選擇相應的預處理方式,比如歸一化等等。預處理後的資料應用下面的程式碼效果一般較為理想。

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Comparator;

public class Kmeans {
	class Node {
		int label;// label用來記錄點屬於第幾個cluster
		double[] attributes;

		public Node() {
			attributes = new double[100];
		}
	}

	class NodeComparator {
		Node nodeOne;
		Node nodeTwo;
		double distance;

		public void compute() {
			double val = 0;
			for (int i = 0; i < dimension; ++i) {
				val += (this.nodeOne.attributes[i] - this.nodeTwo.attributes[i])
						* (this.nodeOne.attributes[i] - this.nodeTwo.attributes[i]);
			}
			this.distance = val;
		}
	}

	ArrayList<Node> arraylist;
	ArrayList<Node> centroidList;
	double averageDis;
	int dimension;
	Queue<NodeComparator> FsQueue = new PriorityQueue<NodeComparator>(150,// 用來排序任意兩點之間的距離,從大到小排
			new Comparator<NodeComparator>() {
				public int compare(NodeComparator one, NodeComparator two) {
					if (one.distance < two.distance)
						return 1;
					else if (one.distance > two.distance)
						return -1;
					else
						return 0;
				}
			});

	public Kmeans(String path) {// 建構函式讀入資料
		try {
			BufferedReader br = new BufferedReader(new FileReader(path));
			String str;
			String[] strArray;
			arraylist = new ArrayList<Node>();
			while ((str = br.readLine()) != null) {
				strArray = str.split(",");
				dimension = strArray.length;
				Node node = new Node();
				for (int i = 0; i < dimension; ++i) {
					node.attributes[i] = Double.parseDouble(strArray[i]);
				}
				arraylist.add(node);
			}
			br.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	public void computeTheK() {
		int cntTuple = 0;
		for (int i = 0; i < arraylist.size() - 1; ++i) {
			for (int j = i + 1; j < arraylist.size(); ++j) {
				NodeComparator nodecomp = new NodeComparator();
				nodecomp.nodeOne = new Node();
				nodecomp.nodeTwo = new Node();
				for (int k = 0; k < dimension; ++k) {
					nodecomp.nodeOne.attributes[k] = arraylist.get(i).attributes[k];
					nodecomp.nodeTwo.attributes[k] = arraylist.get(j).attributes[k];
				}
				nodecomp.compute();
				averageDis += nodecomp.distance;
				FsQueue.add(nodecomp);
				cntTuple++;
			}
		}
		averageDis /= cntTuple;// 計算平均距離
		chooseCentroid(FsQueue);
	}

	public double getDistance(Node one, Node two) {// 計算兩點間的歐氏距離
		double val = 0;
		for (int i = 0; i < dimension; ++i) {
			val += (one.attributes[i] - two.attributes[i])
					* (one.attributes[i] - two.attributes[i]);
		}
		return val;
	}

	public void chooseCentroid(Queue<NodeComparator> queue) {
		centroidList = new ArrayList<Node>();
		boolean flag = false;
		while (!queue.isEmpty()) {
			boolean judgeOne = false;
			boolean judgeTwo = false;
			NodeComparator nc = FsQueue.poll();
			if (nc.distance < averageDis)
				break;// 如果接下來的元組,兩節點間距離小於平均距離,則不繼續迭代
			if (!flag) {
				centroidList.add(nc.nodeOne);// 先加入所有點中距離最遠的兩個點
				centroidList.add(nc.nodeTwo);
				flag = true;
			} else {// 之後從之前已加入的最遠的兩個點開始,找離這兩個點最遠的點,
					//如果距離大於所有點的平均距離,則認為找到了新的質心,否則不認定為質心
				for (int i = 0; i < centroidList.size(); ++i) {
					Node testnode = centroidList.get(i);
					if (centroidList.contains(nc.nodeOne)
							|| getDistance(testnode, nc.nodeOne) < averageDis) {
						judgeOne = true;
					}
					if (centroidList.contains(nc.nodeTwo)
							|| getDistance(testnode, nc.nodeTwo) < averageDis) {
						judgeTwo = true;
					}
				}
				if (!judgeOne) {
					centroidList.add(nc.nodeOne);
				}
				if (!judgeTwo) {
					centroidList.add(nc.nodeTwo);
				}
			}
		}
	}

	public void doIteration(ArrayList<Node> centroid) {

		int cnt = 1;
		int cntEnd = 0;
		int numLabel=centroid.size();
		while (true) {// 迭代,直到所有的質心都不變化為止
			boolean flag = false;
			for (int i = 0; i < arraylist.size(); ++i) {
				double dis = 0x7fffffff;
				cnt = 1;
				for (int j = 0; j < centroid.size(); ++j) {
					Node node = centroid.get(j);
					if (getDistance(arraylist.get(i), node) < dis) {
						dis = getDistance(arraylist.get(i), node);
						arraylist.get(i).label = cnt;
					}
					cnt++;
				}
			}
			int j = 0;
			numLabel-=1;
			while (j < numLabel) {
				int c = 0;
				Node node = new Node();
				for (int i = 0; i < arraylist.size(); ++i) {
					if (arraylist.get(i).label == j + 1) {
						for (int k = 0; k < dimension; ++k) {
							node.attributes[k] += arraylist.get(i).attributes[k];
						}
						c++;
					}
				}
				DecimalFormat df = new DecimalFormat("#.###");// 保留小數點後三位
				double[] attributelist = new double[100];
				for (int i = 0; i < dimension; ++i) {
					attributelist[i] = Double.parseDouble(df
							.format(node.attributes[i] / c));
					if (attributelist[i] != centroid.get(j).attributes[i]) {
						centroid.get(j).attributes[i] = attributelist[i];
						flag = true;
					}
				}
				if (!flag) {
					cntEnd++;
					if (cntEnd == numLabel) {// 若所有的質心都不變,則跳出迴圈
						break;
					}
				}
				j++;
			}
			if (cntEnd == numLabel) {// 若所有的質心都不變,則success
				System.out.println("do kmeans success");
				break;
			}
		}
	}

	public void getKmeansResults(String path) {
		try {
			PrintStream out = new PrintStream(path);
			computeTheK();
			doIteration(centroidList);
			out.println("There are " + centroidList.size() + " clusters!");
			for (int i = 0; i < arraylist.size(); ++i) {
				for (int j = 0; j < dimension; ++j) {
					out.print(arraylist.get(i).attributes[j] + " ");
				}
				out.println("belongs to cluster " + arraylist.get(i).label);
			}
			out.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	public static void main(String[] args) {
		Kmeans kmeans = new Kmeans("c:/kmeans.txt");
		kmeans.getKmeansResults("c:/kmeansResults.txt");

	}
}


相關推薦

Kmeans演算法實現

今天我們介紹資料探勘領域最基本的一個演算法,Kmeans演算法,並進行演算法的講解及實現。 我們知道聚類問題屬於經典問題,而對於聚類演算法,也是有很多不同的種類,kmeans就是其中一種最基本的聚類演算法。 它的主要演算法流程是: (1)隨機的取k個點作為k個初始質心;

堆排序演算法實現-----------c語言

堆排序原理:   堆排序指的是將大堆(小堆)堆頂(即下標為0)元素與堆的最後一個(即下標為hp->size - 1)元素交換,hp->size–,將其餘的元素再次調整成大堆(小堆),再次將堆頂(即下標為0)元素與堆的最後一個(即下標為hp->s

機器學習經典演算法Python實現--線性迴歸(Linear Regression)演算法

(一)認識迴歸 迴歸是統計學中最有力的工具之一。機器學習監督學習演算法分為分類演算法和迴歸演算法兩種,其實就是根據類別標籤分佈型別為離散型、連續性而定義的。顧名思義,分類演算法用於離散型分佈預測,如前

小白之KMP演算法python實現

在看子串匹配問題的時候,書上的關於KMP的演算法的介紹總是理解不了。看了一遍程式碼總是很快的忘掉,後來決定好好分解一下KMP演算法,算是給自己加深印象。 ------------------------- 分割線-------------------------------

迪克斯特拉演算法C++實現

演算法步驟如下: G={V,E} 1. 初始時令 S={V0},T=V-S={其餘頂點},T中頂點對應的距離值 若存在<V0,Vi>,d(V0,Vi)為<V0,Vi>弧上的權值

遺傳演算法Java實現

遺傳演算法的起源 ========== 20世紀60年代中期,美國密西根大學的John Holland提出了位串編碼技術,這種編碼既適合於變異又適合雜交操作,並且他強調將雜交作為主要的遺傳操作。遺傳演算法的通用編碼技術及簡單有效的遺傳操作為其廣泛的應用和成功

機器學習經典演算法Python實現--決策樹(Decision Tree)

(一)認識決策樹 1,決策樹分類原理 決策樹是通過一系列規則對資料進行分類的過程。它提供一種在什麼條件下會得到什麼值的類似規則的方法。決策樹分為分類樹和迴歸樹兩種,分類樹對離散變數做決策樹,迴歸樹對連續變數做決策樹。 近來的調查表明決策樹也是最經常使用的資料探勘演算法,它

常見9大排序演算法python3實現

穩定:如果a原本在b前面,而a=b,排序之後a仍然在b的前面; 不穩定:如果a原本在b的前面,而a=b,排序之後a可能會出現在b的後面; 內排序:所有排序操作都在記憶體中完成; 外排序:由於資料太大,因此把資料放在磁碟中,而排序通過磁碟和記憶體的資料傳輸才能進行;

機器學習經典演算法Python實現--K近鄰(KNN)演算法

轉載http://blog.csdn.net/suipingsp/article/details/41964713 (一)KNN依然是一種監督學習演算法 KNN(K Nearest Neighbors,K近鄰 )演算法是機器學習所有演算法中理論最簡單,最好理解的。KNN

RSA演算法C語言實現

1、什麼是RSA RSA公鑰加密演算法是1977年由羅納德·李維斯特(Ron Rivest)、阿迪·薩莫爾(Adi Shamir)和倫納德·阿德曼(Leonard Adleman)一起提出的。1987年首次公佈,當時他們三人都在麻省理工學院工作。RSA就是他們

結點對最短路徑之Floyd演算法原理實現

上兩篇部落格介紹了計算單源最短路徑的Bellman-Ford演算法和Dijkstra演算法。Bellman-Ford演算法適用於任何有向圖,即使圖中包含負環路,它還能報告此問題。Dijkstra演算法執行速度比Bellman-Ford演算法要快,但是其要求圖中不能包含負權重

最小生成樹-MST演算法程式碼實現

師兄發了一篇CCF-C類的文章,但是那個會議在澳洲排名屬於B類,他說他在CVPR的某一篇文章的基礎之上用了MST演算法,避免了局部最優解,找到了全域性最優解,實驗結果比原文好很多,整個文章加實驗前後僅僅做了2周。真是羨煞旁人也。。其實在機器學習當中,經常會遇到區域性最優解的

redis配置文件實現主從同步切換

redis redis主從 redis配置文件詳解及實現主從同步切換redis復制Redis復制很簡單易用,它通過配置允許slave Redis Servers或者Master Servers的復制品。接下來有幾個關於redis復制的非常重要特性:一個Master可以有多個Slaves。Slaves能

微信和支付寶支付模式實現

配置 其余 logs https 朋友 一個 target 多租戶 對比   繼上篇《微信和支付寶支付模式詳解及實現》到現在已經有半年時間了,這期間不少朋友在公號留言支付相關的問題,最近正好也在處理公司支付相關的對接,打算寫這篇來做一個更進一步的介紹,同時根據主要的幾個支付

Show, attend and tell演算法原始碼

mark一下,感謝作者分享! https://blog.csdn.net/shenxiaolu1984/article/details/51493673 原論文:https://arxiv.org/pdf/1502.03044v2.pdf 原始碼:https://github.c

Kadane演算法求解最大子數列和問題

最大子數列和問題         給出一個數列,現在求其中一個子數列,要求是所有子數列的和的最大值。另外還有其他問法,例如給出一個數組,要求求出連續的元素和的最大值。可以一個例子來解釋: 假設有數列:[-1,2,3,-5,6,-2,4],那麼總共有

各種排序演算法C++實現

1.氣泡排序 時間複雜度 O ( n

KMP演算法各種應用

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

nfs實現全網備份

1.統一hosts cat /etc/hosts 172.16.1.5 lb01 172.16.1.6 lb02 172.16.1.7 web02 172.16.1.8 web01 172.16.1.51 db01 172.16.1.31 nfs01

二叉搜尋樹實現程式碼(BST)

概念 二叉搜尋樹(Binary Search Tree),又稱二叉排序樹,它或者是一顆空樹,或者具有如下性質的樹: 若它的左子樹不為空,則左子樹上所有節點的值都小於根節點的值 若它的右子樹不為空,則右子樹上所有節點的值都大於根節點的值 它的左右子樹也分別