
An Easy-to-Learn Machine Learning Algorithm: the Density-Based Clustering Algorithm DBSCAN
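DBSCAN (Density-Based Spatial Clustering of Applications with Noise) groups points by density rather than by distance to a centroid. A point is a core point if its Eps-neighbourhood contains enough other points (the script below tests length(ind) >= MinPts+1, because the neighbourhood returned by find includes the point itself), a border point if it lies inside some core point's neighbourhood without being a core point, and noise otherwise. The MATLAB script below estimates Eps with the epsilon helper, builds the full distance matrix with calDistance, and then grows one cluster from each unvisited core point.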

%% DBSCAN
clear all;
clc;

%% Load the data set
% data = load('testData.txt');
data = load('testData_2.txt');

% Define the parameters Eps and MinPts (Eps is estimated from the data by the epsilon function)
MinPts = 5;
Eps = epsilon(data, MinPts);

[m,n] = size(data);             % size of the data set

x = [(1:m)' data];              % prepend an index column to every point
[m,n] = size(x);                % recompute the size of the augmented data set
types = zeros(1,m);             % point type: core = 1, border = 0, noise = -1
dealed = zeros(m,1);            % processed flag: 0 = not yet processed
class = zeros(1,m);             % cluster label of every point (0 = unassigned)
dis = calDistance(x(:,2:n));    % pairwise Euclidean distance matrix
number = 1;                     % label of the current cluster

%% Process every point in turn
for i = 1:m
    % only handle points that have not been processed yet
    if dealed(i) == 0
        xTemp = x(i,:);
        D = dis(i,:);           % distances from point i to every other point
        ind = find(D<=Eps);     % all points within radius Eps (point i itself included)
        
        %% Determine the type of the point
        
        % border point: fewer than MinPts neighbours, but not isolated;
        % it is not marked as processed, so a later cluster can still claim it
        if length(ind) > 1 && length(ind) < MinPts+1
            types(i) = 0;
            class(i) = 0;
        end
        % noise point: no other point within Eps
        if length(ind) == 1
            types(i) = -1;
            class(i) = -1;
            dealed(i) = 1;
        end
        % core point (this is the key step)
        if length(ind) >= MinPts+1
            types(xTemp(1,1)) = 1;      % point i is a core point
            class(ind) = number;        % its whole Eps-neighbourhood joins the current cluster
            
            % expand the cluster: visit every point that is density-reachable
            % from the core point through a chain of core points
            while ~isempty(ind)
                yTemp = x(ind(1),:);            % next seed point
                dealed(ind(1)) = 1;
                ind(1) = [];
                D = dis(yTemp(1,1),:);          % distances from the seed point just taken off the list
                ind_1 = find(D<=Eps);           % its Eps-neighbourhood
                
                if length(ind_1) >= MinPts+1    % the seed is itself a core point
                    types(yTemp(1,1)) = 1;
                    class(ind_1) = number;      % its whole neighbourhood joins the cluster
                    
                    % only a core point may add new seeds, so the cluster
                    % keeps growing through chains of core points
                    for j = 1:length(ind_1)
                        if dealed(ind_1(j)) == 0
                            dealed(ind_1(j)) = 1;
                            ind = [ind ind_1(j)];
                        end
                    end
                elseif length(ind_1) > 1        % border point: it keeps the label it already has
                    types(yTemp(1,1)) = 0;
                end
            end
            number = number + 1;
        end
    end
end

% finally, mark every point that was never assigned to a cluster as noise
ind_2 = find(class==0);
class(ind_2) = -1;
types(ind_2) = -1;

%% Plot the final clustering result
hold on
for i = 1:m
    if class(i) == -1
        plot(data(i,1),data(i,2),'.r');
    elseif class(i) == 1
        if types(i) == 1
            plot(data(i,1),data(i,2),'+b');
        else
            plot(data(i,1),data(i,2),'.b');
        end
    elseif class(i) == 2
        if types(i) == 1
            plot(data(i,1),data(i,2),'+g');
        else
            plot(data(i,1),data(i,2),'.g');
        end
    elseif class(i) == 3
        if types(i) == 1
            plot(data(i,1),data(i,2),'+c');
        else
            plot(data(i,1),data(i,2),'.c');
        end
    else
        if types(i) == 1
            plot(data(i,1),data(i,2),'+k');
        else
            plot(data(i,1),data(i,2),'.k');
        end
    end
end
hold off
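The script reads its points from testData_2.txt, which is not reproduced here. The snippet below is a minimal sketch for generating a compatible two-column test file; the three Gaussian blobs and the amount of background noise are assumptions for illustration, not the original data set.

% Hypothetical test data: three Gaussian blobs plus uniform background noise
rng(1);                                         % make the random numbers reproducible
blob1 = 0.5*randn(100,2) + repmat([1 1],100,1);
blob2 = 0.5*randn(100,2) + repmat([4 4],100,1);
blob3 = 0.5*randn(100,2) + repmat([1 5],100,1);
noise = 6*rand(20,2) - 0.5;                     % a few scattered points
data  = [blob1; blob2; blob3; noise];
save('testData_2.txt', 'data', '-ascii');       % same plain-text format that load() reads

Note that the plotting loop only gives the first three clusters their own colours (blue, green, cyan); any additional clusters fall through to the final else branch and are drawn in black together with each other.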

The distance calculation function
%% Compute the pairwise Euclidean distances between all points in the matrix
function [ dis ] = calDistance( x )
    [m,n] = size(x);
    dis = zeros(m,m);
    
    for i = 1:m
        for j = i:m
            % Euclidean distance between point i and point j
            tmp =0;
            for k = 1:n
                tmp = tmp+(x(i,k)-x(j,k)).^2;
            end
            dis(i,j) = sqrt(tmp);
            dis(j,i) = dis(i,j);
        end
    end
end
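The double loop above is straightforward but becomes slow for large data sets. Below is a vectorized sketch of the same computation; calDistanceVec is a hypothetical name, and the one-liner assumes MATLAB R2016b or newer for implicit expansion (with the Statistics and Machine Learning Toolbox, pdist2(x, x) returns the same matrix).

%% Vectorized pairwise Euclidean distances (same result as calDistance)
function [ dis ] = calDistanceVec( x )
    sq  = sum(x.^2, 2);                         % squared norm of every row
    % ||xi - xj||^2 = ||xi||^2 + ||xj||^2 - 2*xi*xj'
    % max(...,0) guards against small negative values from round-off
    dis = sqrt(max(sq + sq.' - 2*(x*x.'), 0));
end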

The epsilon function
function [Eps]=epsilon(x,k)

% Function: [Eps]=epsilon(x,k)
%
% Aim: 
% Analytical way of estimating neighborhood radius for DBSCAN
%
% Input: 
% x - data matrix (m,n); m-objects, n-variables
% k - number of objects in a neighborhood of an object
% (minimal number of objects considered as a cluster)

[m,n]=size(x);

Eps=((prod(max(x)-min(x))*k*gamma(.5*n+1))/(m*sqrt(pi.^n))).^(1/n);
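Written out as an equation, the one-liner above reads (m points, n variables, Γ the gamma function):

$$\mathrm{Eps}=\left(\frac{k\,\Gamma\!\left(\frac{n}{2}+1\right)\prod_{j=1}^{n}\bigl(\max x_j-\min x_j\bigr)}{m\,\sqrt{\pi^{\,n}}}\right)^{1/n}$$

Since the volume of an n-dimensional ball of radius r is $\pi^{n/2} r^n / \Gamma(\frac{n}{2}+1)$, this is the radius at which a ball would contain, on average, k points if the m points were spread uniformly over their bounding box. It is only a heuristic starting value; Eps can of course also be set by hand.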


The final result