程式人生 > KNN分類演算法java實現

KNN分類演算法java實現

最近鄰分類演算法思想

KNN演算法的思想總結一下:就是在訓練集中資料和標籤已知的情況下,輸入測試資料,將測試資料的特徵與訓練集中對應的特徵進行相互比較,找到訓練集中與之最為相似的前K個數據,則該測試資料對應的類別就是K個數據中出現次數最多的那個分類,其演算法的描述為:

1)計算測試資料與各個訓練資料之間的距離;

2)按照距離的遞增關係進行排序;

3)選取距離最小的K個點;

4)確定前K個點所在類別的出現頻率;

5)返回前K個點中出現頻率最高的類別作為測試資料的預測分類。

Java程式碼實現

KNN.java程式碼

public class KNN {

          public static void main(String[] args

) {

                // 一、輸入所有已知點

                List<Point>dataList = creatDataSet();

                // 二、輸入未知點

                Point x = new Point(5, 1.2, 1.2);

                // 三、計算所有已知點到未知點的歐式距離,並根據距離對所有已知點排序

                CompareClass compare = new CompareClass();

                Set<Distance> distanceSet

= new TreeSet<Distance>(compare);

                for (Pointpoint : dataList) {

                     distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point,

                             x)));

                }

                // 四、選取最近的k個點

                double k = 5;

                /**

                 * 五、計算k個點所在分類出現的頻率

                 */

                // 1、計算每個分類所包含的點的個數

                List<Distance> distanceList= new ArrayList<Distance>(distanceSet);

                Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);

                // 2、計算頻率

                Map<String, Double> p = computeP(map, k);

                x.setType(maxP(p));

                System.out.println("未知點的型別為:"+x.getType());

            }

            // 歐式距離計算

            public static double oudistance(Point point1, Pointpoint2) {

                double temp = Math.pow(point1.getX() - point2.getX(), 2)

                         + Math.pow(point1.getY() - point2.getY(), 2);

                return Math.sqrt(temp);

            }

            // 找出最大頻率

            public static String maxP(Map<String,Double> map) {

                String key = null;

                double value = 0.0;

                for (Map.Entry<String, Double> entry : map.entrySet()) {

                     if (entry.getValue() > value) {

                         key = entry.getKey();

                         value = entry.getValue();

                     }

                }

                return key;

            }

            // 計算頻率

            public static Map<String,Double> computeP(Map<String, Integer> map,

                     double k) {

                Map<String, Double> p = new HashMap<String, Double>();

                for (Map.Entry<String, Integer> entry : map.entrySet()) {

                     p.put(entry.getKey(), entry.getValue() / k);

                }

                return p;

            }

            // 計算每個分類包含的點的個數

            public static Map<String,Integer> getNumberOfType(

                     List<Distance> listDistance, List<Point> listPoint, double k) {

                Map<String, Integer> map = new HashMap<String, Integer>();

                int i = 0;

                System.out.println("選取的k個點,由近及遠依次為:");

                for (Distance distance : listDistance) {

                     System.out.println("id" + distance.getId() + ",距離為:"

                             + distance.getDisatance());

                     long id = distance.getId();

                     // 通過id找到所屬型別,並存儲到HashMap

                     for (Point point : listPoint) {

                         if (point.getId() == id) {

                             if (map.get(point.getType()) != null)

                                map.put(point.getType(), map.get(point.getType()) + 1);

                             else {

                                 map.put(point.getType(), 1);

                             }

                         }

                     }

                     i++;

                     if (i >= k)

                         break;

                }

                return map;

            }

            public static ArrayList<Point> creatDataSet(){

                Point point1 = new Point(1, 1.0, 1.1, "A");

                Point point2 = new Point(2, 1.0, 1.0, "A");

                Point point3 = new Point(3, 1.0, 1.2, "A");

                Point point4 = new Point(4, 0, 0, "B");

                Point point5 = new Point(5, 0, 0.1, "B");

                Point point6 = new Point(6, 0, 0.2, "B");

                ArrayList<Point>dataList = new ArrayList<Point>();

                dataList.add(point1);

                dataList.add(point2);

                dataList.add(point3);

                dataList.add(point4);

                dataList.add(point5);

                dataList.add(point6);

                return dataList;

            }

}

類中涉及到的Point類、Distance類,以及比較器CompareClass類如下:

Point

/**
 * A 2-D sample point identified by a numeric id, optionally carrying a class
 * label.
 */
public class Point {

    private long id;

    private double x;

    private double y;

    // Class label; null until one is assigned.
    private String type;

    /**
     * Creates an unlabelled point (its type is {@code null}).
     *
     * @param id identifier of the point
     * @param x  x coordinate
     * @param y  y coordinate
     */
    public Point(long id, double x, double y) {
        this(id, x, y, null);
    }

    /**
     * Creates a labelled point.
     *
     * @param id   identifier of the point
     * @param x    x coordinate
     * @param y    y coordinate
     * @param type class label of the point
     */
    public Point(long id, double x, double y, String type) {
        this.x = x;
        this.y = y;
        this.type = type;
        this.id = id;
    }

    /** @return the identifier of this point */
    public long getId() {
        return id;
    }

    /** @param id the new identifier */
    public void setId(long id) {
        this.id = id;
    }

    /** @return the x coordinate */
    public double getX() {
        return x;
    }

    /** @param x the new x coordinate */
    public void setX(double x) {
        this.x = x;
    }

    /** @return the y coordinate */
    public double getY() {
        return y;
    }

    /** @param y the new y coordinate */
    public void setY(double y) {
        this.y = y;
    }

    /** @return the class label, or {@code null} if none has been set */
    public String getType() {
        return type;
    }

    /** @param type the new class label */
    public void setType(String type) {
        this.type = type;
    }
}

Distance

/**
 * The distance between one known (labelled) point and one unknown point,
 * with both points referenced by id.
 */
public class Distance {

    // Id of the known point.
    private long id;

    // Id of the unknown point.
    private long nid;

    // Distance between the two points. The misspelled name is kept because
    // callers use the matching accessor getDisatance().
    private double disatance;

    /**
     * @param id        id of the known point
     * @param nid       id of the unknown point
     * @param disatance distance between the two points
     */
    public Distance(long id, long nid, double disatance) {
        this.id = id;
        this.nid = nid;
        this.disatance = disatance;
    }

    /** @return the id of the known point */
    public long getId() {
        return id;
    }

    /** @param id the new id of the known point */
    public void setId(long id) {
        this.id = id;
    }

    /** @return the id of the unknown point */
    public long getNid() {
        return nid;
    }

    /** @param nid the new id of the unknown point */
    public void setNid(long nid) {
        this.nid = nid;
    }

    /** @return the distance between the two points */
    public double getDisatance() {
        return disatance;
    }

    /** @param disatance the new distance between the two points */
    public void setDisatance(double disatance) {
        this.disatance = disatance;
    }
}

比較器CompareClass

import java.util.Comparator;

//比較器類

public class CompareClass implements Comparator<Distance>{

    public int compare(Distance d1, Distance d2) {

        return d1.getDisatance()>d2.getDisatance()?20 : -1;

    }

}

另外,若需要將一個Map<String,Double>(例如typeAndDistance)按照distance進行排序,也就是按照map的value進行排序,可以參考如下方法:

1.  public class Testing {  

2.    

3.      public static void main(String[] args) {  

4.    

5.          HashMap<String,Double> map = new HashMap<String,Double>();  

6.          ValueComparator bvc =  new ValueComparator(map);  

7.          TreeMap<String,Double> sorted_map = new TreeMap<String,Double>(bvc);  

8.    

9.          map.put("A",99.5);  

10.         map.put("B",67.4);  

11.         map.put("C",67.4);  

12.         map.put("D",67.3);  

13.   

14.         System.out.println("unsorted map: "+map);  

15.   

16.         sorted_map.putAll(map);  

17.   

18.         System.out.println("results: "+sorted_map);  

19.     }  

20. }  

21.   

22. class ValueComparator implements Comparator<String> {  

23.   

24.     Map<String, Double> base;  

25.     public ValueComparator(Map<String, Double> base) {  

26.         this.base = base;  

27.     }  

28.   

29.     // Note: this comparator imposes orderings that are inconsistent with equals.      

30.     public int compare(String a, String b) {  

31.         if (base.get(a) >= base.get(b)) {  

32.             return -1;  

33.         } else {  

34.             return 1;  

35.         } // returning 0 would merge keys  

36.     }  

37. }