Implementing the k-means algorithm in Java (using the Iris dataset, read from a MySQL database)
阿新 • Published: 2019-02-13
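k-means is a clustering algorithm in machine learning. It is a centroid-based partitioning method in which the center of each cluster is represented by the mean of all objects in that cluster. The idea is as follows:
Input:
- k: the number of clusters;
- D: a dataset containing n objects.
Method:
- Randomly select k objects from D as the initial centroids;
- For each data point, compute its distance to every centroid and assign the point to the cluster whose centroid is nearest;
- For each cluster, compute the mean of all its points and use that mean as the new centroid;
- Reassign the data points to clusters according to the new centroids;
- Repeat steps 3 and 4 until the centroids no longer change (the new centroids equal the previous ones);
- Output the clustering result.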
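The steps above can be sketched compactly before looking at the database-backed version. This is only a minimal illustration, not the implementation from this post: it assumes points are plain `double[]` arrays, takes the starting centroids as an argument instead of choosing them at random, and adds an iteration cap as a safety net. The class name `KMeansSketch` and all method names are made up for the example.

```java
import java.util.Arrays;

class KMeansSketch {
    // Safety cap on iterations (an assumption of this sketch, not part of the steps above)
    static final int MAX_ITERATIONS = 1000;

    /** Squared Euclidean distance between two vectors of equal length. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Index of the centroid closest to p. */
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        for (int c = 1; c < centroids.length; c++) {
            if (squaredDistance(p, centroids[c]) < squaredDistance(p, centroids[best])) {
                best = c;
            }
        }
        return best;
    }

    /** Runs k-means from the given starting centroids and returns the final centroids. */
    static double[][] cluster(double[][] data, double[][] start) {
        int k = start.length;
        int dim = data[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = start[c].clone();
        }
        for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
            // Assignment step: add each point to the running sum of its nearest cluster
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (double[] p : data) {
                int c = nearest(p, centroids);
                counts[c]++;
                for (int i = 0; i < dim; i++) {
                    sums[c][i] += p[i];
                }
            }
            // Update step: the new centroid is the mean of the cluster's points
            double[][] next = new double[k][];
            for (int c = 0; c < k; c++) {
                next[c] = centroids[c].clone(); // an empty cluster keeps its old centroid
                if (counts[c] > 0) {
                    for (int i = 0; i < dim; i++) {
                        next[c][i] = sums[c][i] / counts[c];
                    }
                }
            }
            // Stop once the centroids no longer change
            if (Arrays.deepEquals(next, centroids)) {
                break;
            }
            centroids = next;
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[][] data = { {1, 1}, {1, 2}, {2, 1}, {8, 8}, {9, 8}, {8, 9} };
        double[][] start = { {1, 1}, {8, 8} };
        System.out.println(Arrays.deepToString(cluster(data, start)));
    }
}
```

On this toy dataset the two centroids settle at the means of the two obvious groups, roughly (1.33, 1.33) and (8.33, 8.33).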
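木羊's k-means implementation consists of 5 classes: DBConnection.java connects to the database, SelectData.java reads the data from the database, Point.java models a data point, ManagePoint.java holds the distance, centroid-comparison, and centroid-update helpers, and Kmeans.java runs the algorithm.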
DBConnection.java
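Getting the dataset: the Iris dataset is available from the UCI Machine Learning Repository, the standard site for machine learning datasets. 木羊 has already loaded it from the txt file into the database and dropped the last column (the flower species). Readers unfamiliar with reading from and writing to a database can look the topic up online. Time permitting, 木羊 will cover loading the txt file into the database in a later post.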
package db;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

/**
 * Database connection class.
 */
public class DBConnection {
    public static final String driver = "com.mysql.jdbc.Driver";
    public static final String url = "jdbc:mysql://localhost:3306/mydb";
    public static final String user = "root";
    public static final String pwd = "123";

    public static Connection dBConnection() {
        Connection con = null;
        try {
            // Load the MySQL driver
            Class.forName(driver);
            // Open the database connection
            con = DriverManager.getConnection(url, user, pwd);
        } catch (ClassNotFoundException e) {
            System.out.println("Failed to load the driver");
            e.printStackTrace();
        } catch (SQLException e) {
            System.out.println("Failed to connect to the database");
            e.printStackTrace();
        }
        return con;
    }
}
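This post does not show how the txt file was loaded into the table, so the following is only a rough sketch of one way it could be done. It assumes each line of the UCI file holds four comma-separated measurements followed by the species name (for example `5.1,3.5,1.4,0.2,Iris-setosa`). The class `IrisLoaderSketch` and the column names `sepal_length`, `sepal_width`, `petal_length`, and `petal_width` are hypothetical; only the table name `iris_Kmeans` comes from the code in this post.

```java
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

class IrisLoaderSketch {
    /**
     * Parses one line into its four measurements, dropping the species column.
     * Returns null for a blank line.
     */
    static double[] parseLine(String line) {
        if (line == null || line.trim().isEmpty()) {
            return null;
        }
        String[] parts = line.trim().split(",");
        if (parts.length < 4) {
            throw new IllegalArgumentException("Expected at least 4 fields: " + line);
        }
        double[] values = new double[4];
        for (int i = 0; i < 4; i++) {
            values[i] = Double.parseDouble(parts[i].trim());
        }
        return values;
    }

    /** Parses every non-blank line of the file. */
    static List<double[]> parseAll(List<String> lines) {
        List<double[]> rows = new ArrayList<>();
        for (String line : lines) {
            double[] row = parseLine(line);
            if (row != null) {
                rows.add(row);
            }
        }
        return rows;
    }

    /** Inserts the parsed rows in one batch; the column names are hypothetical. */
    static void insertAll(Connection con, List<double[]> rows) throws SQLException {
        String sql = "insert into iris_Kmeans "
                + "(sepal_length, sepal_width, petal_length, petal_width) values (?, ?, ?, ?)";
        try (PreparedStatement ps = con.prepareStatement(sql)) {
            for (double[] row : rows) {
                for (int i = 0; i < 4; i++) {
                    ps.setDouble(i + 1, row[i]);
                }
                ps.addBatch();
            }
            ps.executeBatch();
        }
    }
}
```

A caller would read the file into a list of lines, pass it to `parseAll`, and hand the result to `insertAll` together with a connection such as the one returned by `DBConnection.dBConnection()`.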
The data fields in the database table are as follows (150 rows in total):
SelectData.java
package dao;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;

import model.Point;
import db.DBConnection;

/**
 * Reads the data points from the database.
 */
public class SelectData {
    public static final String SELECT = "select * from iris_Kmeans";

    /**
     * @return pointsList, one Point per row
     */
    public ArrayList<Point> getPoints() throws SQLException {
        ArrayList<Point> pointsList = new ArrayList<Point>();
        // try-with-resources closes the ResultSet, the statement, and the
        // connection in reverse order, even if reading a row throws
        try (Connection con = DBConnection.dBConnection();
                PreparedStatement pstmt = con.prepareStatement(SELECT);
                ResultSet rs = pstmt.executeQuery()) {
            while (rs.next()) {
                // Columns 2 to 5 are read as the four coordinates of a point
                Point point = new Point();
                point.setX(rs.getDouble(2));
                point.setY(rs.getDouble(3));
                point.setZ(rs.getDouble(4));
                point.setW(rs.getDouble(5));
                pointsList.add(point);
            }
        }
        System.out.println("Dataset: " + pointsList);
        return pointsList;
    }
}
Point.java
Note that equals and hashCode must be overridden here so that centroids can be compared later.
package model;

public class Point {
    private double x;
    private double y;
    private double z;
    private double w;

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }

    public double getZ() {
        return z;
    }

    public void setZ(double z) {
        this.z = z;
    }

    public double getW() {
        return w;
    }

    public void setW(double w) {
        this.w = w;
    }

    public Point() {
    }

    public Point(double x, double y, double z, double w) {
        this.x = x;
        this.y = y;
        this.z = z;
        this.w = w;
    }

    @Override
    public String toString() {
        return "Point [ x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + "]";
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        // Reject null and other types instead of casting blindly
        if (!(obj instanceof Point)) {
            return false;
        }
        Point point = (Point) obj;
        return Double.compare(x, point.x) == 0
                && Double.compare(y, point.y) == 0
                && Double.compare(z, point.z) == 0
                && Double.compare(w, point.w) == 0;
    }

    @Override
    public int hashCode() {
        // Combine all four fields so that equal points share a hash code
        // and unequal points rarely collide
        int result = Double.hashCode(x);
        result = 31 * result + Double.hashCode(y);
        result = 31 * result + Double.hashCode(z);
        result = 31 * result + Double.hashCode(w);
        return result;
    }
}
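To see why the overrides matter when centroids are used as HashMap keys, here is a small stand-alone sketch. It does not use the Point class itself: `PlainPoint` and `ValuePoint` are hypothetical two-field classes written only for this comparison.

```java
import java.util.HashMap;
import java.util.Map;

class EqualsHashCodeSketch {
    /** Overrides neither equals nor hashCode, so lookups compare object identity. */
    static class PlainPoint {
        final double x, y;

        PlainPoint(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Overrides both, so two points with the same coordinates count as the same key. */
    static class ValuePoint {
        final double x, y;

        ValuePoint(double x, double y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof ValuePoint)) {
                return false;
            }
            ValuePoint other = (ValuePoint) obj;
            return Double.compare(x, other.x) == 0 && Double.compare(y, other.y) == 0;
        }

        @Override
        public int hashCode() {
            return 31 * Double.hashCode(x) + Double.hashCode(y);
        }
    }

    /** Looks up a key using a second object with the same coordinates. */
    static boolean plainLookup() {
        Map<PlainPoint, String> map = new HashMap<>();
        map.put(new PlainPoint(5.1, 3.5), "cluster A");
        return map.containsKey(new PlainPoint(5.1, 3.5));
    }

    static boolean valueLookup() {
        Map<ValuePoint, String> map = new HashMap<>();
        map.put(new ValuePoint(5.1, 3.5), "cluster A");
        return map.containsKey(new ValuePoint(5.1, 3.5));
    }

    public static void main(String[] args) {
        System.out.println(plainLookup()); // false
        System.out.println(valueLookup()); // true
    }
}
```

Without the overrides, a freshly computed centroid would never be found in the map of previous centroids, even when its coordinates are identical.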
ManagePoint.java
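This class contains 3 methods: one computes the Euclidean distance between two points, one checks whether the centroids are the same before and after an update, and one computes the updated centroid.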
package util;

import java.util.ArrayList;
import java.util.Map;

import model.Point;

public class ManagePoint {
    /**
     * Computes the Euclidean distance between two points.
     *
     * @param p
     *            the first point
     * @param q
     *            the second point
     * @return distance
     */
    public double getDistance(Point p, Point q) {
        double dx = p.getX() - q.getX();
        double dy = p.getY() - q.getY();
        double dz = p.getZ() - q.getZ();
        double dw = p.getW() - q.getW();
        double distance = Math.sqrt(dx * dx + dy * dy + dz * dz + dw * dw);
        return distance;
    }

    /**
     * Checks whether the centroids are the same before and after an update.
     *
     * @param lastCenterCluster
     *            the previous centroids
     * @param nowCenterCluster
     *            the current centroids
     * @return true if both maps have exactly the same centroids as keys
     */
    public boolean isEqual(Map<Point, ArrayList<Point>> lastCenterCluster,
            Map<Point, ArrayList<Point>> nowCenterCluster) {
        if (lastCenterCluster == null) {
            return false;
        }
        // Compare the whole key sets; checking only the last key visited
        // could report convergence while another centroid is still moving
        return lastCenterCluster.keySet().equals(nowCenterCluster.keySet());
    }

    /**
     * Computes the new centroid of a cluster.
     *
     * @param value
     *            the map value, holding all points in the cluster
     * @return point
     */
    public Point getNewCenter(ArrayList<Point> value) {
        double sumX = 0, sumY = 0, sumZ = 0, sumW = 0;
        for (Point point : value) {
            sumX += point.getX();
            sumY += point.getY();
            sumZ += point.getZ();
            sumW += point.getW();
        }
        // The centroid is the per-coordinate mean of the cluster's points
        int n = value.size();
        Point point = new Point(sumX / n, sumY / n, sumZ / n, sumW / n);
        System.out.println("New centroid: (" + point.getX() + "," + point.getY()
                + "," + point.getZ() + "," + point.getW() + ")");
        return point;
    }
}
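The centroid comparison above relies on exact equality of double values. A common alternative, shown here only as a sketch and not as what this post's code does, is to stop once every centroid has moved by less than a small tolerance. The example assumes the centroids are kept in arrays with a stable order (unlike the HashMap keys used in this post); `ConvergenceSketch`, `hasConverged`, and `epsilon` are names invented for the illustration.

```java
class ConvergenceSketch {
    /** Euclidean distance between two vectors of equal length. */
    static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Returns true when each centroid in now lies within epsilon of the
     * centroid at the same index in last.
     */
    static boolean hasConverged(double[][] last, double[][] now, double epsilon) {
        if (last == null || last.length != now.length) {
            return false;
        }
        for (int i = 0; i < now.length; i++) {
            if (distance(last[i], now[i]) > epsilon) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Distance between two four-dimensional points, about 0.5385
        double[] p = { 5.1, 3.5, 1.4, 0.2 };
        double[] q = { 4.9, 3.0, 1.4, 0.2 };
        System.out.println(distance(p, q));

        double[][] last = { { 1.0, 2.0 }, { 5.0, 5.0 } };
        double[][] now = { { 1.0, 2.0 + 1e-9 }, { 5.0, 5.0 } };
        System.out.println(hasConverged(last, now, 1e-6));
    }
}
```

A tolerance trades a little precision for robustness: the loop ends even if rounding makes the centroids differ in the last few bits from one pass to the next.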
Kmeans.java
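木羊 stores the clusters in a HashMap, where each key is a cluster's centroid and the value holds all points in that cluster. Note in particular that, to get a reasonably good final clustering, the three randomly chosen initial centroids should be taken one each from the row index ranges [0, 50), [50, 100), and [100, 150). In the UCI file the rows are grouped by species, 50 per species, so if the table keeps that order this picks one starting centroid per species.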
package util;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

import model.Point;
import dao.SelectData;

public class Kmeans {
    public Map<Point, ArrayList<Point>> executeKmeans(int k) {
        ArrayList<Point> dataList = new ArrayList<Point>();// the raw data
        Map<Point, ArrayList<Point>> nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();// current centroids and the points in their clusters
        Map<Point, ArrayList<Point>> lastCenterClusterMap = null;// previous centroids and the points in their clusters
        try {
            dataList = new SelectData().getPoints();
            // Randomly pick k data points as the initial centroids,
            // one from each block of 50 rows
            Random rd = new Random();
            int[] initIndex = { 50, 50, 50 };
            int[] tempIndex = { 0, 50, 100 };
            // The index ranges above are hard-coded for at most three clusters
            if (k < 1 || k > initIndex.length) {
                throw new IllegalArgumentException("k must be between 1 and "
                        + initIndex.length);
            }
            System.out.println("Initial centroid indexes: ");
            for (int i = 0; i < k; i++) {
                int index = rd.nextInt(initIndex[i]) + tempIndex[i];
                System.out.println("No. " + (i + 1) + " : " + index);
                nowCenterClusterMap.put(dataList.get(index),
                        new ArrayList<Point>());
            }
            // Print the initial centroids
            System.out.println("Initial centroids: ");
            for (Point point : nowCenterClusterMap.keySet())
                System.out.println("key: " + point);
            ManagePoint managePoint = new ManagePoint();
            while (true) {
                // Add each data point to the value list of its nearest centroid
                for (Point point : dataList) {
                    double shortestDistance = Double.MAX_VALUE;// start from the largest double
                    Point key = null;
                    for (Entry<Point, ArrayList<Point>> entry : nowCenterClusterMap
                            .entrySet()) {
                        // Distance between this centroid and the point
                        double distance = managePoint.getDistance(
                                entry.getKey(), point);
                        if (distance < shortestDistance) {
                            shortestDistance = distance;
                            key = entry.getKey();
                        }
                    }
                    nowCenterClusterMap.get(key).add(point);
                }
                // If the new centroids equal the previous ones, leave the loop
                if (managePoint.isEqual(lastCenterClusterMap,
                        nowCenterClusterMap)) {
                    System.out.println("Centroids unchanged.");
                    break;
                }
                // Update the centroids
                lastCenterClusterMap = nowCenterClusterMap;
                nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();
                System.out.println("------------------------------------------------------------------");
                for (Entry<Point, ArrayList<Point>> entry : lastCenterClusterMap
                        .entrySet()) {
                    nowCenterClusterMap.put(
                            managePoint.getNewCenter(entry.getValue()),
                            new ArrayList<Point>());
                }
            }
        } catch (SQLException e) {
            System.out.println("Database operation failed");
            e.printStackTrace();
        }
        return nowCenterClusterMap;
    }

    public static void main(String[] args) {
        int K = 3;// split into three clusters
        Map<Point, ArrayList<Point>> result = new Kmeans().executeKmeans(K);
        // Print the clusters
        System.out.println("===========Clustering result: ============");
        for (Entry<Point, ArrayList<Point>> entry : result.entrySet()) {
            System.out.println("\n" + "Stable centroid: " + entry.getKey());
            System.out.println("Cluster size: " + entry.getValue().size());
            System.out.println("Points in the cluster: " + entry.getValue());
        }
    }
}
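The starting-index selection in Kmeans.java is written for three clusters of 50 rows each. As a hedged sketch of how the same idea might be generalized, the helper below picks one random row index from each of k consecutive, near-equal slices of the row range. `SeedIndexSketch` and `pickSeedIndexes` are hypothetical names, not part of this post's code.

```java
import java.util.Arrays;
import java.util.Random;

class SeedIndexSketch {
    /**
     * Picks one random index from each of k consecutive slices of [0, n).
     * For n = 150 and k = 3 the slices are [0, 50), [50, 100), and [100, 150).
     */
    static int[] pickSeedIndexes(int n, int k, Random rd) {
        if (k <= 0 || k > n) {
            throw new IllegalArgumentException("need 0 < k <= n");
        }
        int[] indexes = new int[k];
        for (int i = 0; i < k; i++) {
            int start = (int) ((long) i * n / k);     // inclusive lower bound of slice i
            int end = (int) ((long) (i + 1) * n / k); // exclusive upper bound of slice i
            indexes[i] = start + rd.nextInt(end - start);
        }
        return indexes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(pickSeedIndexes(150, 3, new Random())));
    }
}
```

This only helps when the rows are grouped so that each slice roughly matches one natural cluster, as with the species-ordered Iris file; on shuffled data it is no better than picking k indexes at random.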
The five classes above were all copied and pasted from MyEclipse and have been tested to run. The results are as follows:
In testing, no matter which 3 points were randomly chosen as the initial centroids, the final stable centroids were the same.
(Discussion is welcome. The code still has shortcomings, so suggestions are appreciated. Please credit the source when reposting.)