EM 演算法-對鳶尾花資料進行聚類
阿新 • • 發佈:2020-12-14
> **公號:碼農充電站pro**
> **主頁:**
之前介紹過[K 均值演算法](https://www.cnblogs.com/codeshell/p/14084190.html),它是一種聚類演算法。今天介紹**EM 演算法**,它也是聚類演算法,但比**K 均值**演算法更加靈活強大。
**EM** 的全稱為 **Expectation Maximization**,中文為**期望最大化**演算法,它是一個不斷**觀察和調整**的過程。
### 1,和麵過程
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/2020121023550421.png?)
我們先來看一下和麵的過程。
通常情況下,如果你事先不知道面與水的比例,和麵過程可能是下面這樣:
1. 先放入一些面和水。
2. 將麵糰揉拌均勻。
3. 觀察麵糰的稀稠程度:如果麵糰比較稀,則加入少許面;如果麵糰比較稠,則加入少許水。
4. 如此往復第2,3步驟,直到麵糰的稀稠程度達到預期。
這個和麵過程,就是一個**EM** 過程:
- 先加入一些面和水,將麵糰揉拌均勻,並觀察麵糰的稀稠程度。這是**E** 過程。
- 不斷的加入水和麵(調整水和麵的比例),直到達到預期麵糰程度。這是**M** 過程。
### 2,再看K 均值演算法
在介紹[K 均值](https://www.cnblogs.com/codeshell/p/14084190.html) 聚類演算法時,展示過一個給二維座標點進行聚類的例子。
我們再來看一下這個例子,如下圖:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211095215492.png?)
上圖是一個聚類的過程,共有6 個步驟:
1. 初始時散點(綠色點)的分佈。
2. 隨機選出兩個中心點的位置,`紅色x` 和`藍色x`。
3. 計算所有散點分別到`紅色x` 和`藍色x`的距離,距離`紅色x` 近的點標紅色,距離`藍色x`近的點標藍色。
4. 重新計算兩個中心點的位置,兩個中心點分別移動到新的位置。
5. 重新計算所有散點分別到`紅色x` 和`藍色x`的距離,距離`紅色x` 近的點標紅色,距離`藍色x`近的點標藍色。
6. 再次計算兩個中心點的位置,兩個中心點分別移動到新的位置。中心點的位置幾乎不再變化,聚類結束。
經過以上步驟就完成了一個聚類過程。
實際上,**K 均值**演算法也是一個**EM** 過程:
- 確定當前各類中心點的位置,並計算各個散點到現有的中心點的距離。這是**E** 過程。
- 將各個散點歸屬到各個類中,重新計算各個類的中心點,直到各類的中心點不再改變。這是**M** 過程。
### 3,EM 演算法
將二維資料點的聚類過程,擴充套件為一般性的聚類問題,**EM** 演算法是這樣一個模型:對於待分類的資料點,**EM** 演算法讓計算機通過一個不斷迭代的過程,來構建一個分類模型。
**EM 演算法**分為兩個過程:
- **E 過程**:根據現有的模型,計算各個資料輸入到模型中的計算結果,這稱為**期望值計算過程**,即 **E** 過程。
- **M 過程**:重新計算模型引數,以最大化期望值,這稱為**最大化過程**,即**M** 過程。
以二維資料點的聚類過程為例,我們定義:
- 同一類中各個點到該類中心的平均距離為 **d**;
- 不同類之間的平均距離為 **D**。
那麼二維資料點聚類的**M** 過程,就是尋求最大化的**D** 和 **-d**。我們希望的聚類結果是,同一類的點距離較近,不同類之間距離較遠。
**EM 演算法**不是單個演算法,而是一類演算法。只要滿足**EM** 這兩個過程的演算法都可以被稱為**EM 演算法**。常見的**EM 演算法**有**GMM 高斯混合模型**和**HMM 隱馬爾科夫模型**。
### 4,最大似然估計
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211165648336.png?)
高等數學中有一門課叫做《概率論與數理統計》,其中講到了**引數估計**。
**統計推斷**是數理統計的重要組成部分,它是指利用來自總體的樣本提供的資訊,對總體的某些特徵進行估計或推斷,從而認識整體。
**統計推斷**分為兩大類:**引數估計**和**假設檢驗**。
我們假設,對於某個資料集,其分佈函式的**基本形式**已知,但其中含有一個或多個**未知引數**。
**引數估計**就是討論如何根據來自總體的樣本提供的資訊對未知引數做出估計。引數估計包括**點估計**和**區間估計**。其中,點估計中有兩種方法:**矩估計法**和**最大似然估計法**。
**最大似然估計**是一種通過已知結果,估計**未知引數**的方法。
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211172724453.png)
### 5,EM 演算法原理
**EM 演算法**使用的是**最大似然估計**的原理,它通過觀察樣本,來找出樣本的模型引數。
下面通過一個投硬幣的例子,來看下**EM 演算法**的計算過程。
這個例子來自《*Nature*》(自然)期刊的論文《*What is the expectation maximization algorithm?*》(什麼是期望最大化演算法?)。
假定有兩枚不同的硬幣 A 和 B,它們的重量分佈 **θ****A** 和 **θ****B** 是未知的,則可以通過拋擲硬幣,計算正反面各自出現的次數來估計**θ****A** 和 **θ****B**。
方法是在每一輪中隨機抽出一枚硬幣拋擲 10 次,同樣的過程執行 5 輪,根據這 50 次投幣的結果來計算 **θ****A** 和 **θ****B** 的**最大似然估計**。
投擲硬幣的過程,記錄如下:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211194515456.png?)
第1 到5 次分別投擲的硬幣是 **B,A,A,B,A**。**H** 代表正面,**T** 代表負面。將上圖轉化為表格,如下:
| 次數 | 硬幣 | 正面數 | 負面數 |
|--|--|--|--|
| 1 | B | 5 | 5 |
| 2 | A | 9 | 1 |
| 3 | A | 8 | 2 |
| 4 | B | 4 | 6 |
| 5 | A | 7 | 3 |
通過這個表格,可以直接計算 **θ****A** 和 **θ****B**,如下:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/2020121119525944.png)
顯然,如果**知道**每次投擲的硬幣是A 還是B,那麼計算**θ****A** 和 **θ****B** 是非常簡單的。
但是,如果**不知道**每次投擲的硬幣是A 還是B,該如何計算**θ****A** 和 **θ****B** 呢?
此時我們將上面表格中的**硬幣**一列隱藏起來,這時**硬幣**就是**隱變數**。所以我們只知道如下資料:
| 次數 | 正面數 | 負面數 |
|--|--|--|
| 1 | 5 | 5 |
| 2 | 9 | 1 |
| 3 | 8 | 2 |
| 4 | 4 | 6 |
| 5 | 7 | 3 |
這時想要計算 **θ****A** 和 **θ****B**,就要用**最大似然估計**的原理。
計算過程如下圖:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211195945957.png?)
***第一步***
先為 **θ****A** 和 **θ****B** 設定一個初始值,比如 **θ****A** = 0.6,**θ****B** = 0.5。
***第二步***
我們知道每一輪投幣的正 **/** 負面的次數:
- 第1輪:5 正 5 負,計算出現這種結果的概率:
- 如果是A 硬幣,那麼P(H5T5|A) = 0.6^5 * 0.4^5
- 如果是B 硬幣,那麼P(H5T5|B) = 0.5^5 * 0.5^5
- 將 P(H5T5|A) 和 P(H5T5|B) 歸一化處理,可得:
- P(H5T5|A) = **0.45**,P(H5T5|B) = **0.55**
- 第2輪:9 正 1 負,計算出現這種結果的概率:
- 如果是A 硬幣,那麼P(H9T1|A) = 0.6^9 * 0.4^1
- 如果是B 硬幣,那麼P(H9T1|B) = 0.5^9 * 0.5^1
- 將 P(H9T1|A) 和 P(H9T1|B) 歸一化處理,可得:
- P(H9T1|A) = **0.8**,P(H9T1|B) = **0.2**
- 第3輪:8 正 2 負,計算出現這種結果的概率:
- 如果是A 硬幣,那麼P(H8T2|A) = 0.6^8 * 0.4^2
- 如果是B 硬幣,那麼P(H8T2|B) = 0.5^8 * 0.5^2
- 將 P(H8T2|A) 和 P(H8T2|B) 歸一化處理,可得:
- P(H8T2|A) = **0.73**,P(H8T2|B) = **0.27**
- 第4輪:4 正 6 負,計算出現這種結果的概率:
- 如果是A 硬幣,那麼P(H4T6|A) = 0.6^4 * 0.4^6
- 如果是B 硬幣,那麼P(H4T6|B) = 0.5^4 * 0.5^6
- 將 P(H4T6|A) 和 P(H4T6|B) 歸一化處理,可得:
- P(H4T6|A) = **0.35**,P(H4T6|B) = **0.65**
- 第5輪:7 正 3 負,計算出現這種結果的概率:
- 如果是A 硬幣,那麼P(H7T3|A) = 0.6^7 * 0.4^3
- 如果是B 硬幣,那麼P(H7T3|B) = 0.5^7 * 0.5^3
- 將 P(H7T3|A) 和 P(H7T3|B) 歸一化處理,可得:
- P(H7T3|A) = **0.65**,P(H7T3|B) = **0.35**
然後,根據每一輪的 P(HmTn|A) 和 P(HmTn|B),可以計算出每一輪的正 **/** 負面次數。
> **m** 為正面次數,**n** 為負面次數。
對於硬幣A,結果如下:
| 輪數 | P(HmTn\|A) | m | n | 正面數 | 負面數 |
|--|--|--|--|--|--|
| 1 | 0.45 | 5 | 5 | 0.45*5=**2.2** | 0.45*5=**2.2** |
| 2 | 0.8 | 9 | 1 | 0.8*9=**7.2** | 0.8*1=**0.8** |
| 3 | 0.73 | 8 | 2 | 0.73*8=**5.9** | 0.73*2=**1.5** |
| 4 | 0.35 | 4 | 6 | 0.35*4=**1.4** | 0.35*6=**2.1** |
| 5 | 0.65 | 7 | 3 | 0.65*7=**4.5** | 0.65*3=**1.9** |
| **總計** |- | - | - |**21.3** | **8.6** |
對於硬幣B,結果如下:
| 輪數 | P(HmTn\|B) | m | n | 正面數 | 負面數 |
|--|--|--|--|--|--|
| 1 | 0.55 |5|5| 0.55*5=**2.8** | 0.55*5=**2.8** |
| 2 | 0.2 |9|1|0.2*9=**1.8** | 0.2*1=**0.2** |
| 3 | 0.27 |8|2|0.27*8=**2.1** | 0.27*2=**0.5** |
| 4 |0.65 |4|6|0.65*4=**2.6** | 0.65*6=**3.9** |
| 5 | 0.35 |7|3| 0.35*7=**2.5** | 0.35*3=**1.1** |
| **總計** |-|-|-| **11.7** | **8.4** |
***第三步***
根據上面兩個表格,可以得出(第1次迭代的結果) **θ****A** 和 **θ****B**:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211210914660.png)
根據這個估計值,再次回到第一步去計算。
如此往復第**一、二、三**步,經過10次迭代之後,**θ****A** 和 **θ****B** 的估計值為:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211211150924.png)
最終,**θ****A** 和 **θ****B** 將收斂到一個幾乎不變的值,此時迭代結束。這樣我們就求解出了**θ****A** 和 **θ****B** 的最大似然估計值。
我們將上述過程中,第一步稱為**初始化引數**,第二步稱為**觀察預期**,第三步稱為**重新估計引數**。
第**一、二**步為**E** 過程,第**三**步為**M** 過程,這就是**EM 演算法**的過程。
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201211212444598.png?)
如果我們有一個待聚類的資料集,我們把潛在的類別當做**隱變數**,樣本當做**觀察值**,這樣就可以把聚類問題轉化成**引數估計**問題。這就是**EM 聚類**的原理。
### 6,硬聚類與軟聚類
與 **K 均值**演算法相比,**K 均值**演算法是通過**距離**來區分樣本之間的差別,且每個樣本在計算的時候只能屬於一個分類,我們稱之為**硬聚類**演算法。
而 **EM 聚類**在求解的過程中,實際上每個樣本都有一定的概率和每個聚類相關,這叫做**軟聚類**演算法。
### 7,EM 聚類的缺點
EM 聚類演算法存在兩個比較明顯的問題。
第一個問題是,**EM 演算法**計算複雜,收斂較慢,不太適合大規模資料集和高維資料。
第二個問題是,**EM 演算法**不一定能給出**全域性最優解**:
- 當優化的**目標函式**是**凸函式**時,一定可以得到**全域性最優解**。
- 當優化的目標函式不是凸函式時,可能會得到**區域性最優解**,而非全域性最優解。
### 8,GMM 高斯混合模型
上文中介紹過,常見的**EM 演算法**有**GMM 高斯混合模型**和**HMM 隱馬爾科夫模型**。這裡主要介紹**GMM 高斯混合模型**的實現。
**sklearn** 庫的**mixture** 模組中的[GaussianMixture](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html) 類是[GMM 演算法](https://scikit-learn.org/stable/modules/mixture.html)的實現。
先來看下 **GaussianMixture** 類的原型:
```python
GaussianMixture(
n_components=1,
covariance_type='full',
tol=0.001,
reg_covar=1e-06,
max_iter=100,
n_init=1,
init_params='kmeans',
weights_init=None,
means_init=None,
precisions_init=None,
random_state=None,
warm_start=False,
verbose=0,
verbose_interval=10)
```
這裡介紹幾個重要的引數:
- n_components:代表高斯混合模型的個數,也就是我們要聚類的個數,預設值為 1。
- covariance_type:代表協方差型別。一個高斯混合模型的分佈是由**均值向量**和**協方差矩陣**決定的,所以協方差的型別也代表了不同的高斯混合模型的特徵。協方差型別有 4 種取值:
- full,代表完全協方差,也就是元素都不為 0;
- tied,代表相同的完全協方差;
- diag,代表對角協方差,也就是對角不為 0,其餘為 0;
- spherical,代表球面協方差,非對角為 0,對角完全相同,呈現球面的特性。
- max_iter:代表最大迭代次數,預設值為 100。
### 9,對鳶尾花資料集聚類
在[《決策樹演算法-實戰篇-鳶尾花及波士頓房價預測》](https://www.cnblogs.com/codeshell/p/13984334.html)中我們介紹過[鳶尾花資料集](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/datasets/data/iris.csv)。這裡我們使用**GMM 演算法**對該資料進行**聚類**。
首先載入資料集:
```python
from sklearn.datasets import load_iris
iris = load_iris() # 載入資料集
features = iris.data # 獲取特徵集
labels = iris.target # 獲取目標集
```
在聚類演算法中,只需要特徵資料 `features`,而不需要目標資料`labels`,但可以使用 `labels` 對聚類的結果做驗證。
構造GMM聚類:
```python
from sklearn.mixture import GaussianMixture
# 原資料中有 3 個分類,所以這裡我們將 n_components 設定為 3
gmm = GaussianMixture(n_components=3, covariance_type='full')
```
對資料集進行聚類:
```python
prediction_labels = gmm.fit_predict(features)
```
檢視原始分類:
```python
>>> print(labels)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
```
檢視聚類結果:
```python
>>> print(prediction_labels)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 1 2 1
1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
```
對比原始分類和聚類結果,聚類結果中只有個別資料分類錯誤,我用`紅圈`標了出來:
![在這裡插入圖片描述](https://img-blog.csdnimg.cn/20201212102419198.png?)
### 10,評估聚類結果
我們可以使用 **Calinski-Harabaz 指標**對聚類結果進行評估。
**sklearn** 庫實現了該指標的計算,即 [calinski_harabasz_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.calinski_harabasz_score.html) 方法,該方法會計算出一個分值,分數越高,代表聚類效果越好,也就是相同類中的差異性小,不同類之間的差異性大。
下面對**鳶尾花資料集**的聚類結果進行評估,傳入**特徵資料**和**聚類結果**:
```python
>>> from sklearn.metrics import calinski_harabasz_score
>>> calinski_harabasz_score(features, prediction_labels)
481.78070899745234
```
我們也可以傳入**特徵資料**和**原始結果**:
```python
>>> calinski_harabasz_score(features, labels)
487.33087637489984
```
可以看到,對於**原始結果**計算出的分值是**487.33**,對於**預測結果**計算出的分值是**481.78**,相差並不多,說明預測結果還是不錯。
一般情況下,一個需要聚類的資料集並沒有目標資料,所以只能對**預測結果**進行評分。我們需要人工對聚類的含義結果進行分析。
### 11,總結
本篇文章主要介紹瞭如下內容:
- **EM 演算法**的過程及原理,介紹了一個**投擲硬幣**的例子。
- 硬聚類與軟聚類的區別。
- EM 聚類的缺點:
- 計算複雜度較大。
- 有可能得不到全域性最優解。
- 使用[GMM 演算法](https://scikit-learn.org/stable/modules/mixture.html)對鳶尾花資料進行聚類。
- 使用 **Calinski-Harabaz 指標**對聚類結果進行評估。
(本節完。)
---
**推薦閱讀:**
[K 均值演算法-如何讓資料自動分組](https://www.cnblogs.com/codeshell/p/14084190.html)
[Apriori 演算法-如何進行關聯規則挖掘](https://www.cnblogs.com/codeshell/p/14113600.html)
[PageRank 演算法-Google 如何給網頁排名](https://www.cnblogs.com/codeshell/p/14106948.html)
[資料變換-歸一化與標準化](https://www.cnblogs.com/codeshell/p/14060164.html)
[如何使用Python 進行資料視覺化](https://www.cnblogs.com/codeshell/p/14066350.html)
---
*歡迎關注作者公眾號,獲取更多技術乾貨。*
![碼農充電站pro](https://img-blog.csdnimg.cn/20200505082843773.png?#pic