1. 程式人生 > >【機器學習筆記12】聚類(k-means)

【機器學習筆記12】聚類(k-means)

K-means 演算法

演算法流程如下: (1)在樣本中選擇兩個點(也可以是若干個)作為種子點; (2)計算其餘各個樣本離該種子點的距離,並將其分為兩類; (3)將種子點移到(2)所分為的兩類的中間; (4)重複(2)(3)直到種子不再移動;

K-means 演算法程式(基於sklearn)
# -*- coding: utf-8 -*-
import numpy  as np
import matplotlib.pyplot as plt
import sklearn.datasets as ds
from sklearn.cluster import KMeans


def _test_kmeans():

    """
    初始化資料,原始資料data,和目標分類y
    """
    data, y = ds.make_blobs(300, n_features=2, centers=2, random_state=3)

    """
    對原始資料data進行2分類的,k均值聚類
    """
    model = KMeans(n_clusters=2, init='k-means++')

    #y_pre為根據聚類演算法得到的分類結果
    y_pre = model.fit_predict(data)

    plt.figure(figsize=(5, 6), facecolor='w')
    plt.subplot(211)
    plt.title('origin classfication')
    plt.scatter(data[:, 0], data[:, 1], c=y, s=30, edgecolors='none')

    plt.subplot(212)
    plt.title('k-means classfication')
    plt.scatter(data[:, 0], data[:, 1], c=y_pre, s=30, edgecolors='none')

    plt.show()

    pass

"""
說明:

K均值程式碼實現,對應的筆記《09.聚類(K-means)》

作者:fredric

日期:2018-8-2

"""
if __name__ == "__main__":

    _test_kmeans()

在這裡插入圖片描述