1. 程式人生 > >java實現k-means演算法(用的鳶尾花iris的資料集,從mysq資料庫中讀取資料)

java實現k-means演算法(用的鳶尾花iris的資料集,從mysq資料庫中讀取資料)

k-means演算法又稱k-均值演算法,是機器學習聚類演算法中的一種,是一種基於形心的劃分方法,其中每個簇的中心都用簇中所有物件的均值來表示。其思想如下:

輸入:

  • k:簇的數目;
  • D:包含n個物件的資料集。
輸出:k個簇的集合。

方法:

  1. 從D中隨機選擇幾個物件作為起始質心;
  2. 對每個質心,計算每個資料到各個質心的距離,並把這些點分配到離該質心最短的距離的簇;
  3. 對每個簇,計算簇中所有點的均值並將此均值作為新的質心;
  4. 將資料點按照新的中心重新聚類;
  5. 重複【步驟3】,直到質心不再發生變化(新的質心和原來的質心相等);
  6. 輸出聚類結果。
演算法實現:

木羊的k-means演算法實現包括5各類。其中,DBConnection.java用於連線資料庫,SelectData.java用於從資料庫裡讀取資料,

Point.java存放點物件模型,ManagePoint.java是對點的操作,Kmeans.java是演算法的核心思想及主函式入口。以下分別給出各個類的詳細程式碼:

DBConnection.java

資料集獲取,在機器學習資料集獲取官方網站UCI中點選開啟連結,木羊已經把該資料集從txt文件中插入到資料庫,並去除了最後一列(花類別)。讀者若不熟悉資料庫的讀寫,請百度。若木羊有時間,會在後面的博文中補充把txt文件內容讀到資料庫中的內容。

<span style="font-size:18px;">package db;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

/**
 * 
 * 資料庫連線類
 * 
 */
public class DBConnection {
	public static final String driver = "com.mysql.jdbc.Driver";
	public static final String url = "jdbc:mysql://localhost:3306/mydb";
	public static final String user = "root";
	public static final String pwd = "123";

	public static Connection dBConnection() {
		Connection con = null;
		try {
			// 載入mysql驅動器
			Class.forName(driver);
			// 建立資料庫連線
			con = DriverManager.getConnection(url, user, pwd);
		} catch (ClassNotFoundException e) {
			// TODO Auto-generated catch block
			System.out.println("載入驅動器失敗");
			e.printStackTrace();
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			System.out.println("註冊驅動器失敗");
			e.printStackTrace();
		}
		return con;
	}
}</span>

資料庫中的資料欄位如下(共有150條資料):


SelectData.java

package dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;

import model.Point;
import db.DBConnection;

/**
 * 
 * 取出資料
 * 
 * @return pointList
 * 
 */
public class SelectData {
	public static final String SELECT = "select* from iris_Kmeans";

	public ArrayList<Point> getPoints() throws SQLException {
		ArrayList<Point> pointsList = new ArrayList<Point>();
		Connection con = DBConnection.dBConnection();
		ResultSet rs;
		// 建立一個PreparedStatement物件
		PreparedStatement pstmt = con.prepareStatement(SELECT);
		rs = pstmt.executeQuery();
		while (rs.next()) {
			Point point = new Point();
			point.setX(rs.getDouble(2));
			point.setY(rs.getDouble(3));
			point.setZ(rs.getDouble(4));
			point.setW(rs.getDouble(5));
			pointsList.add(point);
		}
		System.out.println("資料集: " + pointsList);
		pstmt.close();
		rs.close();
		con.close();
		return pointsList;
	}
}

Point.java

此處要注意重寫equal和hashcode方法以便後面質心的比較。

package model;

public class Point {
	private double x;
	private double y;
	private double z;
	private double w;

	public double getX() {
		return x;
	}

	public void setX(double x) {
		this.x = x;
	}

	public double getY() {
		return y;
	}

	public void setY(double y) {
		this.y = y;
	}

	public double getZ() {
		return z;
	}

	public void setZ(double z) {
		this.z = z;
	}

	public double getW() {
		return w;
	}

	public void setW(double w) {
		this.w = w;
	}

	public Point() {

	}

	public Point(double x, double y, double z, double w) {
		super();
		this.x = x;
		this.y = y;
		this.z = z;
		this.w = w;
	}

	@Override
	public String toString() {
		return "Point [ x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + "]";
	}

	@Override
	public boolean equals(Object obj) {
		Point point = (Point) obj;
		if (this.getX() == point.getX() && this.getY() == point.getY()
				&& this.getZ() == point.getZ() && this.getW() == point.getW()) {
			return true;
		}
		return false;
	}

	@Override
	public int hashCode() {
		return (int) (x + y + z + w);
	}
}
ManagePoint.java

該類包含了3個方法,分別用於計算兩個點的歐氏距離,比較前後兩個質心是否相同,更新質心。

package util;

import java.util.ArrayList;
import java.util.Map;

import model.Point;

public class ManagePoint {
	/**
	 * 
	 * 計算兩點之間的距離
	 * 
	 * @param p
	 *            第一個點
	 * @param q
	 *            第二個點
	 * @return distance
	 * 
	 */
	public double getDistance(Point p, Point q) {
		double dx = p.getX() - q.getX();
		double dy = p.getY() - q.getY();
		double dz = p.getZ() - q.getZ();
		double dw = p.getW() - q.getW();
		double distance = Math.sqrt(dx * dx + dy * dy + dz * dz + dw * dw);
		return distance;
	}

	/**
	 * 判斷前後兩個質心是否相同
	 * 
	 * @param nowCenterCluster
	 *            現在的質心
	 * @param lastCenterCluster
	 *            上一次的質心
	 * @return boolean
	 * 
	 */

	public boolean isEqual(Map<Point, ArrayList<Point>> lastCenterCluster,
			Map<Point, ArrayList<Point>> nowCenterCluster) {
		boolean contain = false;
		if (lastCenterCluster == null)
			return false;
		else {
			for (Point point : nowCenterCluster.keySet()) {
				contain = lastCenterCluster.containsKey(point);
			}
			if (contain)
				return true;
		}
		return false;
	}

	/**
	 * 
	 * 計算新的質心
	 * 
	 * @param value
	 *            map中的值,存放簇中的所有點
	 * @return point
	 * 
	 */
	public Point getNewCenter(ArrayList<Point> value) {
		double sumX = 0, sumY = 0, sumZ = 0, sumW = 0;
		for (Point point : value) {
			sumX += point.getX();
			sumY += point.getY();
			sumZ += point.getZ();
			sumW += point.getW();
		}
		System.out.println("新的質心: (" + sumX / value.size() + "," + sumY
				/ value.size() + "," + sumZ / value.size() + "," + sumW
				/ value.size() + ")");
		Point point = new Point();
		point.setX(sumX / value.size());
		point.setY(sumY / value.size());
		point.setZ(sumZ / value.size());
		point.setW(sumW / value.size());
		return point;
	}
}

Kmeans.java

木羊把簇存在hashmap裡,其中key存放該簇的質心,value存放該簇的所有點。特別注意的是,為了使最終聚類相對較理想,隨機選擇的三個初始質心應該在[0-50)、[50-100)、[100-150]三個區間內。

package util;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

import model.Point;
import dao.SelectData;

public class Kmeans {

	public Map<Point, ArrayList<Point>> executeKmeans(int k) {
		ArrayList<Point> dataList = new ArrayList<Point>();// 存放原始資料
		Map<Point, ArrayList<Point>> nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();// 當前質心及其簇內的點
		Map<Point, ArrayList<Point>> lastCenterClusterMap = null;// 上一個質心及其簇內的點
		try {
			dataList = new SelectData().getPoints();
			// 隨機建立K個點作為起始質心
			Random rd = new Random();
			int[] initIndex = { 50, 50, 50 };
			int[] tempIndex = { 0, 50, 100 };
			System.out.println("起始質心下標: ");
			for (int i = 0; i < k; i++) {
				int index = rd.nextInt(initIndex[i]) + tempIndex[i];
				System.out.println("第" + (i + 1) + "個 : " + index);
				nowCenterClusterMap.put(dataList.get(index),
						new ArrayList<Point>());
			}
			// 輸出起始質心
			System.out.println("起始質心: ");
			for (Point point : nowCenterClusterMap.keySet())
				System.out.println("key:  " + point);

			// 將資料點point加入配到離其最近的map的value中
			ManagePoint managePoint = new ManagePoint();
			while (true) {
				for (Point point : dataList) {
					double shortestDistance = Double.MAX_VALUE;// 初始化最短距離為Double的最大值
					Point key = null;
					for (Entry<Point, ArrayList<Point>> entry : nowCenterClusterMap
							.entrySet()) {
						// 計算質心與各點間的距離
						double distance = managePoint.getDistance(
								entry.getKey(), point);
						if (distance < shortestDistance) {
							shortestDistance = distance;
							key = entry.getKey();
						}
					}
					nowCenterClusterMap.get(key).add(point);
				}
				// 如果新的質心與上次的質心相等,則退出整個迴圈
				if (managePoint.isEqual(lastCenterClusterMap,
						nowCenterClusterMap)) {
					System.out.println("相等了。");
					break;
				}
				// 更新質心
				lastCenterClusterMap = nowCenterClusterMap;
				nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();
				System.out.println("------------------------------------------------------------------");
				for (Entry<Point, ArrayList<Point>> entry : lastCenterClusterMap
						.entrySet()) {
					nowCenterClusterMap.put(
							managePoint.getNewCenter(entry.getValue()),
							new ArrayList<Point>());
				}
			}
		} catch (SQLException e) {
			// TODO Auto-generated catch block
			System.out.println("資料庫操作失敗");
			e.printStackTrace();
		}
		return nowCenterClusterMap;
	}

	public static void main(String[] args) {
		int K = 3;// 分為三個類
		Map<Point, ArrayList<Point>> result = new Kmeans().executeKmeans(K);
		// 輸出分類
		System.out.println("===========聚類結果: ============");
		for (Entry<Point, ArrayList<Point>> entry : result.entrySet()) {
			System.out.println("\n" + "穩定的質心: " + entry.getKey());
			System.out.println("該簇的大小: " + entry.getValue().size());
			System.out.println("簇裡的點:" + entry.getValue());
		}
	}
}

以上程式碼均從MyEclipse上覆制貼上而來,親測可執行,結果如下:



經測試,無論初始質心被隨機選擇成哪3個,最終穩定的質心都不變。

(歡迎討論。程式碼尚有不完善之處,請多多指教。轉載請註明出處。)