Implementing the K-Means Clustering Algorithm in Java

This is my first blog post, so it is fairly informal.

There are plenty of introductions to K-Means; if you are not familiar with it yet, look up some background material first.

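I summarize the implementation in four steps:

1. Choose k and randomly generate k initial centroids.

2. Compute the distance from every point to each of the k centroids and assign the point to the nearest one.

3. The point set is now partitioned into k clusters; compute the new centroid (the mean) of each cluster.

4. Repeat steps 2 and 3 over all points. Once the centroid of every cluster stops changing, the algorithm has converged.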
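
The four steps can be sketched on one-dimensional data, where they fit in a single method. This is only a minimal illustration, not the implementation presented in this post: the class name `KMeans1D`, the method `fit`, and the iteration cap are hypothetical, and step 1 is replaced by caller-supplied starting centroids so that the result is reproducible.

```java
import java.util.Arrays;

// Hypothetical, minimal 1-D K-means sketch; not the implementation from this post.
class KMeans1D {

    /** Runs steps 2-4 on 1-D data, starting from the given centroids (step 1). */
    static double[] fit(double[] data, double[] initialCentroids, int maxIterations) {
        double[] centroids = initialCentroids.clone();
        for (int iter = 0; iter < maxIterations; iter++) {
            double[] sum = new double[centroids.length];
            int[] count = new int[centroids.length];

            // Step 2: assign each point to its nearest centroid.
            for (double x : data) {
                int nearest = 0;
                for (int c = 1; c < centroids.length; c++) {
                    if (Math.abs(x - centroids[c]) < Math.abs(x - centroids[nearest])) {
                        nearest = c;
                    }
                }
                sum[nearest] += x;
                count[nearest]++;
            }

            // Step 3: move each centroid to the mean of its points.
            double[] updated = centroids.clone();
            for (int c = 0; c < centroids.length; c++) {
                if (count[c] > 0) { // an empty cluster keeps its old centroid
                    updated[c] = sum[c] / count[c];
                }
            }

            // Step 4: stop once no centroid moved.
            if (Arrays.equals(updated, centroids)) {
                break;
            }
            centroids = updated;
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[] data = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        double[] result = fit(data, new double[]{0.0, 5.0}, 100);
        System.out.println(Arrays.toString(result)); // prints [2.0, 11.0]
    }
}
```

Starting from centroids 0 and 5, the first pass yields 1.5 and 9, the second yields 2 and 11, and the third pass changes nothing, so the loop stops.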


The code follows.

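1. The entry class. It reads the data source, trains the model, and writes the result. The data file and the source code are linked at the end of the post.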

package com.hyr.kmeans;

import au.com.bytecode.opencsv.CSVReader;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KmeansMain {

    public static void main(String[] args) throws IOException {
        // Read the data source file
        CSVReader reader = new CSVReader(new FileReader("src/main/resources/data.csv")); // data source
        FileWriter writer = new FileWriter("src/main/resources/out.csv");
        List<String[]> myEntries = reader.readAll(); // each row looks like: 6.8, 12.6

        // Convert the rows into data points
        List<Point> points = new ArrayList<Point>(); // data point set
        for (String[] entry : myEntries) {
            points.add(new Point(Float.parseFloat(entry[0]), Float.parseFloat(entry[1])));
        }

        int k = 6; // value of K (number of clusters)
        int type = 1; // distance type: 1 = L1, 2 = Canberra, 3 = Euclidean
        KmeansModel model = Kmeans.run(points, k, type);

        writer.write("====================   K is " + model.getK() + " ,  Object Funcion Value is " + model.getOfv() + " ,  calc_distance_type is " + model.getCalc_distance_type() + "   ====================\n");
        int i = 0;
        for (Cluster cluster : model.getClusters()) {
            i++;
            writer.write("====================   classification " + i + "   ====================\n");
            for (Point point : cluster.getPoints()) {
                writer.write(point.toString() + "\n");
            }
            writer.write("\n");
            writer.write("centroid is " + cluster.getCentroid().toString());
            writer.write("\n\n");
        }

        reader.close();
        writer.close();

    }

}


2. The model class, i.e. the final trained result: the clusters, the value of K, the distance type used, and the objective function value.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class KmeansModel {

    private List<Cluster> clusters = new ArrayList<Cluster>();
    private Double ofv;
    private int k;  // value of K
    private int calc_distance_type;

    public KmeansModel(List<Cluster> clusters, Double ofv, int k, int calc_distance_type) {
        this.clusters = clusters;
        this.ofv = ofv;
        this.k = k;
        this.calc_distance_type = calc_distance_type;
    }

    public List<Cluster> getClusters() {
        return clusters;
    }

    public Double getOfv() {
        return ofv;
    }

    public int getK() {
        return k;
    }

    public int getCalc_distance_type() {
        return calc_distance_type;
    }
}

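3. The data point class. It holds the point's coordinates (the code only handles two dimensions, x and y) and the distance calculation. The distance formula is selected by type; several common formulas are provided.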

package com.hyr.kmeans;

public class Point {

    private Float x;     // x axis
    private Float y;    // y axis

    public Point(Float x, Float y) {
        this.x = x;
        this.y = y;
    }

    public Float getX() {
        return x;
    }

    public void setX(Float x) {
        this.x = x;
    }

    public Float getY() {
        return y;
    }

    public void setY(Float y) {
        this.y = y;
    }

    @Override
    public String toString() {
        return "Point{" +
                "x=" + x +
                ", y=" + y +
                '}';
    }

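    /**
     * Computes the distance from this point to a centroid.
     *
     * @param centroid the centroid
     * @param type     distance formula: 1 = L1, 2 = Canberra, 3 = Euclidean
     * @return the distance
     * @throws IllegalArgumentException if type is not 1, 2 or 3
     */
    public Double calculateDistance(Point centroid, int type) {
        switch (type) {
            case 1:
                return calcL1Distance(centroid);
            case 2:
                return calcCanberraDistance(centroid);
            case 3:
                return calcEuclidianDistance(centroid);
            default:
                // Fail fast instead of returning null, which would cause a NullPointerException at the caller
                throw new IllegalArgumentException("Unknown distance type: " + type);
        }
    }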



    /*
            Distance formulas
     */

    // L1 (Manhattan) distance, averaged over the two dimensions
    private Double calcL1Distance(Point centroid) {
        double res = Math.abs(getX() - centroid.getX()) + Math.abs(getY() - centroid.getY());
        return res / 2.0;
    }

    private double calcEuclidianDistance(Point centroid) {
        return Math.sqrt(Math.pow((centroid.getX() - getX()), 2) + Math.pow((centroid.getY() - getY()), 2));
    }

    private double calcCanberraDistance(Point centroid) {
        double res = 0;
        res = Math.abs(getX() - centroid.getX()) / (Math.abs(getX()) + Math.abs(centroid.getX()))
                + Math.abs(getY() - centroid.getY()) / (Math.abs(getY()) + Math.abs(centroid.getY()));
        return res / (double) 2;
    }

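    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Point)) { // also covers null
            return false;
        }
        Point other = (Point) obj;
        return getX().equals(other.getX()) && getY().equals(other.getY());
    }

    // Keep hashCode consistent with equals
    @Override
    public int hashCode() {
        return 31 * getX().hashCode() + getY().hashCode();
    }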
}
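
To make the three formulas easier to compare, here is a small standalone sketch that evaluates them on one pair of points. The class `DistanceDemo` and its method names are hypothetical and not part of the project. Unlike `Point` above, it uses the textbook definitions without dividing by 2; that scaling does not change which centroid is nearest. It also assumes that a Canberra term with a zero denominator should count as 0.

```java
// Hypothetical helper for illustration; the names are not from the original project.
class DistanceDemo {

    /** Manhattan (L1) distance: sum of the absolute coordinate differences. */
    static double manhattan(double x1, double y1, double x2, double y2) {
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }

    /** Euclidean (L2) distance: length of the straight line between the points. */
    static double euclidean(double x1, double y1, double x2, double y2) {
        return Math.hypot(x1 - x2, y1 - y2);
    }

    /** Canberra distance: each difference is scaled by the size of its coordinates. */
    static double canberra(double x1, double y1, double x2, double y2) {
        return term(x1, x2) + term(y1, y2);
    }

    // Assumption: a term whose denominator is zero contributes 0 instead of NaN.
    private static double term(double a, double b) {
        double denominator = Math.abs(a) + Math.abs(b);
        return denominator == 0 ? 0 : Math.abs(a - b) / denominator;
    }

    public static void main(String[] args) {
        // Points (1, 2) and (4, 6): the differences are 3 and 4.
        System.out.println("L1        = " + manhattan(1, 2, 4, 6));
        System.out.println("Euclidean = " + euclidean(1, 2, 4, 6));
        System.out.println("Canberra  = " + canberra(1, 2, 4, 6));
    }
}
```

For the points (1, 2) and (4, 6), the L1 distance is 3 + 4 = 7, the Euclidean distance is 5, and the Canberra distance is 3/5 + 4/8, about 1.1.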

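4. The cluster class produced by training. It holds the cluster's centroid, the set of points that belong to the cluster, and whether the cluster has converged.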

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class Cluster {

    private List<Point> points = new ArrayList<Point>(); // points that belong to this cluster
    private Point centroid; // centroid of this cluster
    private boolean isConvergence = false;

    public Point getCentroid() {
        return centroid;
    }

    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    @Override
    public String toString() {
        return centroid.toString();
    }

    public List<Point> getPoints() {
        return points;
    }

    public void setPoints(List<Point> points) {
        this.points = points;
    }


    public void initPoint() {
        points.clear();
    }

    public boolean isConvergence() {
        return isConvergence;
    }

    public void setConvergence(boolean convergence) {
        isConvergence = convergence;
    }
}

5. The K-Means training class. It trains by repeating the four steps described above.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Kmeans {

    /**
     * Runs K-Means.
     *
     * @param points the data set
     * @param k      the value of K
     * @param type   how distance is computed (1 = L1, 2 = Canberra, 3 = Euclidean)
     */
    public static KmeansModel run(List<Point> points, int k, int type) {
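        // Initialize the centroids
        List<Cluster> clusters = initCentroides(points, k);

        while (!checkConvergence(clusters)) { // have all clusters converged?
            // 1. Compute distances and classify every point
            // 2. Check whether each centroid changed; an unchanged centroid means the cluster has converged
            // 3. Regenerate the centroids
            initClusters(clusters); // clear the points of every cluster
            classifyPoint(points, clusters, type); // compute distances and classify
            recalcularCentroides(clusters); // recompute the centroids
        }

        // Compute the objective function value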
        Double ofv = calcularObjetiFuncionValue(clusters);

        KmeansModel kmeansModel = new KmeansModel(clusters, ofv, k, type);

        return kmeansModel;
    }

    /**
     * Initializes k centroids.
     *
     * @param points the point set
     * @param k      the value of K
     * @return the list of cluster objects
     */
    private static List<Cluster> initCentroides(List<Point> points, Integer k) {
        List<Cluster> centroides = new ArrayList<Cluster>();

        // Find the bounds of the data set (minimum and maximum x and y over all points)
        Float max_X = Float.NEGATIVE_INFINITY;
        Float max_Y = Float.NEGATIVE_INFINITY;
        Float min_X = Float.POSITIVE_INFINITY;
        Float min_Y = Float.POSITIVE_INFINITY;
        for (Point point : points) {
            max_X = max_X < point.getX() ? point.getX() : max_X;
            max_Y = max_Y < point.getY() ? point.getY() : max_Y;
            min_X = min_X > point.getX() ? point.getX() : min_X;
            min_Y = min_Y > point.getY() ? point.getY() : min_Y;
        }
        System.out.println("min_X" + min_X + ",max_X:" + max_X + ",min_Y" + min_Y + ",max_Y" + max_Y);

        // Randomly initialize k centroids inside those bounds
        Random random = new Random();
        for (int i = 0; i < k; i++) {
            float x = random.nextFloat() * (max_X - min_X) + min_X;
            float y = random.nextFloat() * (max_Y - min_Y) + min_Y; // offset by min_Y so y stays within [min_Y, max_Y]
            Cluster c = new Cluster();
            Point centroide = new Point(x, y); // the randomly initialized centroid
            c.setCentroid(centroide);
            centroides.add(c);
        }

        return centroides;
    }

    /**
     * Recomputes the centroid of every cluster.
     *
     * @param clusters the clusters
     */
    private static void recalcularCentroides(List<Cluster> clusters) {
        for (Cluster c : clusters) {
            if (c.getPoints().isEmpty()) {
                // An empty cluster has no mean, so its centroid cannot move
                c.setConvergence(true);
                continue;
            }

            // Use the mean of the points as the new centroid
            Float x;
            Float y;
            Float sum_x = 0f;
            Float sum_y = 0f;
            for (Point point : c.getPoints()) {
                sum_x += point.getX();
                sum_y += point.getY();
            }
            x = sum_x / c.getPoints().size();
            y = sum_y / c.getPoints().size();
            Point nuevoCentroide = new Point(x, y); // the new centroid

            // A cluster has converged only if its centroid did not move in this pass.
            // Setting the flag on every pass clears a stale "converged" mark if the centroid moves again.
            c.setConvergence(nuevoCentroide.equals(c.getCentroid()));
            c.setCentroid(nuevoCentroide);
        }
    }

    /**
     * Computes distances and assigns every point to a cluster.
     *
     * @param points   the point set
     * @param clusters the clusters
     * @param type     how distance is computed
     */
    private static void classifyPoint(List<Point> points, List<Cluster> clusters, int type) {
        for (Point point : points) {
            Cluster masCercano = clusters.get(0); // the cluster this point ends up in
            Double minDistancia = Double.MAX_VALUE; // smallest distance so far
            for (Cluster cluster : clusters) {
                Double distancia = point.calculateDistance(cluster.getCentroid(), type); // distance from the point to this cluster's centroid
                if (minDistancia > distancia) { // keep the smallest of the k distances
                    minDistancia = distancia;
                    masCercano = cluster; // the point's cluster
                }
            }
            masCercano.getPoints().add(point); // add the point to the nearest cluster
        }
    }

    private static void initClusters(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            cluster.initPoint();
        }
    }

    /**
     * Checks for convergence.
     *
     * @param clusters the clusters
     * @return true if every cluster has converged
     */
    private static boolean checkConvergence(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            if (!cluster.isConvergence()) {
                return false;
            }
        }
        return true;
    }

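    /**
     * Computes the objective function value.
     *
     * @param clusters the clusters
     * @return the sum of the distances from every point to the centroid of its cluster
     */
    private static Double calcularObjetiFuncionValue(List<Cluster> clusters) {
        Double ofv = 0d;

        for (Cluster cluster : clusters) {
            for (Point point : cluster.getPoints()) {
                int type = 1; // always distance type 1, independent of the type used for training
                ofv += point.calculateDistance(cluster.getCentroid(), type);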
            }
        }

        return ofv;
    }
}
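
The training loop treats a cluster as converged only when the float coordinates of its centroid are exactly equal from one pass to the next. A common alternative is to stop once every centroid has moved less than a small tolerance, usually combined with a cap on the number of iterations. The sketch below shows only the tolerance check; the class `ConvergenceCheck`, the method `hasConverged`, and the tolerance value are assumptions for illustration and do not appear in the project.

```java
// Hypothetical helper; the names and the tolerance are assumptions, not part of the original code.
class ConvergenceCheck {

    /**
     * Returns true when every centroid moved by at most epsilon (Euclidean distance).
     * Each centroid is a two-element array {x, y}; both arrays list the centroids in the same order.
     */
    static boolean hasConverged(double[][] oldCentroids, double[][] newCentroids, double epsilon) {
        for (int i = 0; i < oldCentroids.length; i++) {
            double dx = oldCentroids[i][0] - newCentroids[i][0];
            double dy = oldCentroids[i][1] - newCentroids[i][1];
            if (Math.hypot(dx, dy) > epsilon) {
                return false; // this centroid is still moving
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[][] before = {{1.0, 1.0}, {5.0, 5.0}};
        double[][] tinyMove = {{1.0, 1.00001}, {5.0, 5.0}};
        double[][] bigMove = {{1.0, 1.5}, {5.0, 5.0}};
        System.out.println(hasConverged(before, tinyMove, 1e-4)); // prints true
        System.out.println(hasConverged(before, bigMove, 1e-4));  // prints false
    }
}
```

With a tolerance, the loop no longer depends on two floating-point means being bit-for-bit identical, at the cost of having to choose a value of epsilon that suits the scale of the data.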


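Final training result (the header reports calc_distance_type 3, so this run was presumably made with type = 3 rather than the type = 1 set in main above):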

====================   K is 6 ,  Object Funcion Value is 21.82857036590576 ,  calc_distance_type is 3   ====================
====================   classification 1   ====================
Point{x=3.5, y=12.5}

centroid is Point{x=3.5, y=12.5}

====================   classification 2   ====================
Point{x=6.8, y=12.6}
Point{x=7.8, y=12.2}
Point{x=8.2, y=11.1}
Point{x=9.6, y=11.1}

centroid is Point{x=8.1, y=11.75}

====================   classification 3   ====================
Point{x=4.4, y=6.5}
Point{x=4.8, y=1.1}
Point{x=5.3, y=6.4}
Point{x=6.6, y=7.7}
Point{x=8.2, y=4.5}
Point{x=8.4, y=6.9}
Point{x=9.0, y=3.4}

centroid is Point{x=6.671428, y=5.2142863}

====================   classification 4   ====================
Point{x=6.0, y=19.9}
Point{x=6.2, y=18.5}
Point{x=5.3, y=19.4}
Point{x=7.6, y=17.4}

centroid is Point{x=6.275, y=18.800001}

====================   classification 5   ====================
Point{x=0.8, y=9.8}
Point{x=1.2, y=11.6}
Point{x=2.8, y=9.6}
Point{x=3.8, y=9.9}

centroid is Point{x=2.15, y=10.225}

====================   classification 6   ====================
Point{x=6.1, y=14.3}

centroid is Point{x=6.1, y=14.3}



Code download:

http://download.csdn.net/download/huangyueranbbc/10267041

GitHub:

https://github.com/huangyueranbbc/KmeansDemo