1. 程式人生 > >機器學習(4)--層次聚類(hierarchical clustering)基本原理及實現簡單圖片分類

機器學習(4)--層次聚類(hierarchical clustering)基本原理及實現簡單圖片分類

關於層次聚類(hierarchical clustering)的基本步驟:
1、假設每個樣本為一類,計算每個類的距離,也就是相似度
2、把最近的兩個合為一新類,這樣類別數量就少了一個
3、重新新類與各個舊類(去了那兩個合併的類)之間的相似度;
4、迴圈重複2和3直到所有樣本點都歸為一類

這個計算的過程,相當於重構一個二叉樹,只是這個過程,是從樹葉-->樹枝-->樹幹的構建過程

本例將以14張圖片,做為樣本,進行聚類,點選這裡  下載圖片樣本

以下是使用我提供的圖片庫生成的分類結果,以及一張PS修後對程式碼中各變數的說明



當然,你也可以自己定義一個目錄,程式會讀取目錄下所有JPG圖片

如果你用了自己的圖片,在程式碼中的一此資料的變化說明,可能和使用產生的資料不同了,

同時,本文的主要目的是層次聚類(hierarchical clustering)的基本步驟,對於圖片的相似度的演算法並不完善,效果也並不是十分理想,不過如果你使用自己從手機中匯入的生活照,不同的場景大致還是能分類出來的

# -*- coding:utf-8 -*-

from PIL import ImageDraw,Image
import numpy as np
import os
import sys


nodeList = []#用於儲存所有的節點,包含圖片節點,與聚類後的節點
distance = {}#用於儲存所有每兩個節點的距離,資料格式{(node1.id,node2.id):30.0,(node2.id,node3.id):40.0}
class node:
    def __init__(self, data):
        '''每個樣本及樣本合併後節點的類
            data:接受兩種格式,
            1、當為字元(string)時,是圖片的地址,同時也表示這個節點就是圖片
            2、合併後的類,傳入的格式為(leftNode,rightNode) 即當前類表示合併後的新類,而對應的左右節點就是子節點
        '''
        self.id = len(nodeList)#設定一個ID,以nodeList當然長度為ID,在本例中ID本身沒太大用處,只是如果看程式碼時,有時要看指向時有點用
        self.parent = None # 指向合併後的類
        self.pos = None#用於最後繪製節構圖使用,賦值時為(x,y,w,h)格式
        if type(data) == type("") :
            '''節點為圖片'''
            self.imgData = Image.open(data)
            self.left = None
            self.right = None 
            self.level = 0    #圖片為最終的子節點,所有圖片的層級都為0,設定層級是為了最終繪製結構圖

            npTmp = np.array(self.imgData).reshape(-1,3) #將圖片資料轉化為numpy資料,shape為(高,寬,3),3為顏色通道
            npTmp = npTmp.reshape(-1,3)  #重新排列,shape為(高*寬,3)
            self.feature = npTmp.mean(axis=0)#計算RGB三個顏色通道均值

        else:
            '''節點為合成的新類'''
            self.imgData = None
            self.left = data[0]
            self.right = data[1]
            self.left.parent = self
            self.right.parent = self

            self.level = max(self.left.level,self.right.level) + 1 #層級為左右節高層級的級數+1
            self.feature = (self.left.feature + self.right.feature) / 2 #兩類的合成一類時,就是左右節點的feature相加/2
            
        #計算該類與每個其他類的距離,並存入distance
        for x in nodeList:
            distance[(x,self)] = np.sqrt(np.sum((x.feature - self.feature) ** 2))

        nodeList.append(self)#將本類加入nodeList變數

    def drawNode(self,img,draw,vLineLenght):
        #繪製結構圖
        if self.pos == None:return
        if self.left == None:
            #如果是圖片
            self.imgData.thumbnail((self.pos[2], self.pos[3]))
            img.paste(self.imgData,(self.pos[0], self.pos[1]))
            draw.line((int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1] - vLineLenght
                 , int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1])
                , fill=(255, 0, 0))
        else:
            #如果不是圖片
            draw.line((int(self.pos[0])
                 , self.pos[1]
                 , int(self.pos[0] + self.pos[2])
                 , self.pos[1])
                , fill=(255, 0, 0))

            draw.line((int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1]
                    , int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1] - self.pos[3])
                    , fill=(255, 0, 0))

def loadImg(path):
    '''path 圖片目錄,根據自己存的地方改寫'''
    files = None
    try :
        files = os.listdir(path)
    except:
        print('未正確讀取目錄:' + path + ',圖片目錄,請根據自己存的地方改寫,並保證沒有hierarchicalResult.jpg,該檔案為最後生成檔案')
        return None
    for i in files:

        if os.path.splitext(i)[1].lower() == '.jpg' and os.path.splitext(i)[0].lower() != 'hierarchicalresult':

            fileName = os.path.join(path,i)
            node(fileName)
    return os.path.join(path,'hierarchicalResult.jpg')

def getMinDistance():
    '''從distance中過濾出未分類的結點,並讀取最小的距離'''
    vars = list(filter(lambda x:x[0].parent == None and x[1].parent == None ,distance))
    minDist = vars[0]
    for x in vars:
        if minDist == None or distance[x] < distance[minDist]:
            minDist = x
    return minDist

def createTree():
    while len(list(filter(lambda x:x.parent == None ,nodeList))) > 1:#合併到最後時,只有一個類,只要有兩個以上未合併,就迴圈
        minDist = getMinDistance()
        #建立非圖片的節點,之所以把[1]做為左節點,因為繪圖時的需要,
        #在不斷的產生非圖片節點時,在nodeList的後面的一般是新節點,但繪圖時繪在左邊
        node((minDist[1],minDist[0])) 
    return nodeList[-1]#最後一個插入的節點就是要節點


def run():
    root = createTree()#建立樹結構

    #一句話的PYTON,實現二叉樹的左右根遍歷,通過通過遍歷,進行排序後,取出圖片,做為最底層的列印
    sortTree = lambda node:([] if node.left == None else sortTree(node.left)) + ([] if node.right == None else sortTree(node.right)) + [node]
    treeTmp = sortTree(root)
    treeTmp = list(filter(lambda x:x.left == None,treeTmp))#沒有左節點的,即為圖片

    thumbSize = 60 #縮圖的大小,,在60X60的小格內縮放
    thumbSpace = 20 #縮圖間距
    vLineLenght = 80 #上下節點,即每個level之間的高度

    imgWidth = len(treeTmp) * (thumbSize + thumbSpace)
    imgHeight = (root.level+1) * vLineLenght + thumbSize + thumbSpace*2
    img = Image.new('RGB', (imgWidth,imgHeight), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    for item in enumerate(treeTmp):
        #為所有圖片增加繪圖資料
        x = item[0] * (thumbSize + thumbSpace) + thumbSpace / 2
        y = imgHeight - thumbSize - thumbSpace / 2 - ((item[1].parent.level - 1) * vLineLenght)
        w = item[1].imgData.width
        h = item[1].imgData.height
        if w > h:
            h = h / w * thumbSize
            w = thumbSize
        else:
            w = w / h * thumbSize
            h = thumbSize
            x+=(thumbSize - w) / 2
        item[1].pos = (int(x),int(y),int(w),int(h))
        item[1].drawNode(img,draw,vLineLenght)

    for x in range(1,root.level + 1):
        #為所有非圖片增加繪圖的資料
        items = list(filter(lambda i:i.level == x,nodeList))
        for item in items:
            x = item.left.pos[0] + (item.left.pos[2] / 2)
            w = item.right.pos[0] + (item.right.pos[2] / 2) - x
            y = item.left.pos[1] - (item.level - item.left.level) * vLineLenght
            h = ((item.parent.level if item.parent != None else item.level + 1) - item.level) * vLineLenght
            item.pos = (int(x),int(y),int(w),int(h))
            item.drawNode(img,draw,vLineLenght)
    img.save(resultFile)

resultFile = loadImg(r"E:\hierarchicalImgs")#讀取資料,並返回最後結果要儲存的檔名,目錄根據自己存的位置進行修改
if resultFile != 'None':
    run()
    print("結構圖生成成功,最終結構圖儲存於:" + resultFile)