機器學習(4)--層次聚類(hierarchical clustering)基本原理及實現簡單圖片分類
阿新 • • 發佈:2019-01-08
關於層次聚類(hierarchical clustering)的基本步驟:
1、假設每個樣本為一類,計算每個類的距離,也就是相似度
2、把最近的兩個合為一新類,這樣類別數量就少了一個
3、重新新類與各個舊類(去了那兩個合併的類)之間的相似度;
4、迴圈重複2和3直到所有樣本點都歸為一類
這個計算的過程,相當於重構一個二叉樹,只是這個過程,是從樹葉-->樹枝-->樹幹的構建過程
本例將以14張圖片,做為樣本,進行聚類,點選這裡 下載圖片樣本
1、假設每個樣本為一類,計算每個類的距離,也就是相似度
2、把最近的兩個合為一新類,這樣類別數量就少了一個
3、重新新類與各個舊類(去了那兩個合併的類)之間的相似度;
4、迴圈重複2和3直到所有樣本點都歸為一類
這個計算的過程,相當於重構一個二叉樹,只是這個過程,是從樹葉-->樹枝-->樹幹的構建過程
本例將以14張圖片,做為樣本,進行聚類,點選這裡 下載圖片樣本
以下是使用我提供的圖片庫生成的分類結果,以及一張PS修後對程式碼中各變數的說明
當然,你也可以自己定義一個目錄,程式會讀取目錄下所有JPG圖片
如果你用了自己的圖片,在程式碼中的一此資料的變化說明,可能和使用產生的資料不同了,同時,本文的主要目的是層次聚類(hierarchical clustering)的基本步驟,對於圖片的相似度的演算法並不完善,效果也並不是十分理想,不過如果你使用自己從手機中匯入的生活照,不同的場景大致還是能分類出來的
# -*- coding:utf-8 -*- from PIL import ImageDraw,Image import numpy as np import os import sys nodeList = []#用於儲存所有的節點,包含圖片節點,與聚類後的節點 distance = {}#用於儲存所有每兩個節點的距離,資料格式{(node1.id,node2.id):30.0,(node2.id,node3.id):40.0} class node: def __init__(self, data): '''每個樣本及樣本合併後節點的類 data:接受兩種格式, 1、當為字元(string)時,是圖片的地址,同時也表示這個節點就是圖片 2、合併後的類,傳入的格式為(leftNode,rightNode) 即當前類表示合併後的新類,而對應的左右節點就是子節點 ''' self.id = len(nodeList)#設定一個ID,以nodeList當然長度為ID,在本例中ID本身沒太大用處,只是如果看程式碼時,有時要看指向時有點用 self.parent = None # 指向合併後的類 self.pos = None#用於最後繪製節構圖使用,賦值時為(x,y,w,h)格式 if type(data) == type("") : '''節點為圖片''' self.imgData = Image.open(data) self.left = None self.right = None self.level = 0 #圖片為最終的子節點,所有圖片的層級都為0,設定層級是為了最終繪製結構圖 npTmp = np.array(self.imgData).reshape(-1,3) #將圖片資料轉化為numpy資料,shape為(高,寬,3),3為顏色通道 npTmp = npTmp.reshape(-1,3) #重新排列,shape為(高*寬,3) self.feature = npTmp.mean(axis=0)#計算RGB三個顏色通道均值 else: '''節點為合成的新類''' self.imgData = None self.left = data[0] self.right = data[1] self.left.parent = self self.right.parent = self self.level = max(self.left.level,self.right.level) + 1 #層級為左右節高層級的級數+1 self.feature = (self.left.feature + self.right.feature) / 2 #兩類的合成一類時,就是左右節點的feature相加/2 #計算該類與每個其他類的距離,並存入distance for x in nodeList: distance[(x,self)] = np.sqrt(np.sum((x.feature - self.feature) ** 2)) nodeList.append(self)#將本類加入nodeList變數 def drawNode(self,img,draw,vLineLenght): #繪製結構圖 if self.pos == None:return if self.left == None: #如果是圖片 self.imgData.thumbnail((self.pos[2], self.pos[3])) img.paste(self.imgData,(self.pos[0], self.pos[1])) draw.line((int(self.pos[0] + self.pos[2] / 2) , self.pos[1] - vLineLenght , int(self.pos[0] + self.pos[2] / 2) , self.pos[1]) , fill=(255, 0, 0)) else: #如果不是圖片 draw.line((int(self.pos[0]) , self.pos[1] , int(self.pos[0] + self.pos[2]) , self.pos[1]) , fill=(255, 0, 0)) draw.line((int(self.pos[0] + self.pos[2] / 2) , self.pos[1] , int(self.pos[0] + self.pos[2] / 2) , self.pos[1] - self.pos[3]) , fill=(255, 0, 0)) def loadImg(path): '''path 圖片目錄,根據自己存的地方改寫''' files = None try : files = os.listdir(path) except: print('未正確讀取目錄:' + path + ',圖片目錄,請根據自己存的地方改寫,並保證沒有hierarchicalResult.jpg,該檔案為最後生成檔案') return None for i in files: if os.path.splitext(i)[1].lower() == '.jpg' and os.path.splitext(i)[0].lower() != 'hierarchicalresult': fileName = os.path.join(path,i) node(fileName) return os.path.join(path,'hierarchicalResult.jpg') def getMinDistance(): '''從distance中過濾出未分類的結點,並讀取最小的距離''' vars = list(filter(lambda x:x[0].parent == None and x[1].parent == None ,distance)) minDist = vars[0] for x in vars: if minDist == None or distance[x] < distance[minDist]: minDist = x return minDist def createTree(): while len(list(filter(lambda x:x.parent == None ,nodeList))) > 1:#合併到最後時,只有一個類,只要有兩個以上未合併,就迴圈 minDist = getMinDistance() #建立非圖片的節點,之所以把[1]做為左節點,因為繪圖時的需要, #在不斷的產生非圖片節點時,在nodeList的後面的一般是新節點,但繪圖時繪在左邊 node((minDist[1],minDist[0])) return nodeList[-1]#最後一個插入的節點就是要節點 def run(): root = createTree()#建立樹結構 #一句話的PYTON,實現二叉樹的左右根遍歷,通過通過遍歷,進行排序後,取出圖片,做為最底層的列印 sortTree = lambda node:([] if node.left == None else sortTree(node.left)) + ([] if node.right == None else sortTree(node.right)) + [node] treeTmp = sortTree(root) treeTmp = list(filter(lambda x:x.left == None,treeTmp))#沒有左節點的,即為圖片 thumbSize = 60 #縮圖的大小,,在60X60的小格內縮放 thumbSpace = 20 #縮圖間距 vLineLenght = 80 #上下節點,即每個level之間的高度 imgWidth = len(treeTmp) * (thumbSize + thumbSpace) imgHeight = (root.level+1) * vLineLenght + thumbSize + thumbSpace*2 img = Image.new('RGB', (imgWidth,imgHeight), (255, 255, 255)) draw = ImageDraw.Draw(img) for item in enumerate(treeTmp): #為所有圖片增加繪圖資料 x = item[0] * (thumbSize + thumbSpace) + thumbSpace / 2 y = imgHeight - thumbSize - thumbSpace / 2 - ((item[1].parent.level - 1) * vLineLenght) w = item[1].imgData.width h = item[1].imgData.height if w > h: h = h / w * thumbSize w = thumbSize else: w = w / h * thumbSize h = thumbSize x+=(thumbSize - w) / 2 item[1].pos = (int(x),int(y),int(w),int(h)) item[1].drawNode(img,draw,vLineLenght) for x in range(1,root.level + 1): #為所有非圖片增加繪圖的資料 items = list(filter(lambda i:i.level == x,nodeList)) for item in items: x = item.left.pos[0] + (item.left.pos[2] / 2) w = item.right.pos[0] + (item.right.pos[2] / 2) - x y = item.left.pos[1] - (item.level - item.left.level) * vLineLenght h = ((item.parent.level if item.parent != None else item.level + 1) - item.level) * vLineLenght item.pos = (int(x),int(y),int(w),int(h)) item.drawNode(img,draw,vLineLenght) img.save(resultFile) resultFile = loadImg(r"E:\hierarchicalImgs")#讀取資料,並返回最後結果要儲存的檔名,目錄根據自己存的位置進行修改 if resultFile != 'None': run() print("結構圖生成成功,最終結構圖儲存於:" + resultFile)