KNN分類演算法java實現
最近鄰分類演算法思想
KNN演算法的思想總結一下:就是在訓練集中資料和標籤已知的情況下,輸入測試資料,將測試資料的特徵與訓練集中對應的特徵進行相互比較,找到訓練集中與之最為相似的前K個數據,則該測試資料對應的類別就是K個數據中出現次數最多的那個分類,其演算法的描述為:
1)計算測試資料與各個訓練資料之間的距離;
2)按照距離的遞增關係進行排序;
3)選取距離最小的K個點;
4)確定前K個點所在類別的出現頻率;
5)返回前K個點中出現頻率最高的類別作為測試資料的預測分類。
Java程式碼實現
KNN.java程式碼
public class KNN {
public static void main(String[] args
// 一、輸入所有已知點
List<Point>dataList = creatDataSet();
// 二、輸入未知點
Point x = new Point(5, 1.2, 1.2);
// 三、計算所有已知點到未知點的歐式距離,並根據距離對所有已知點排序
CompareClass compare = new CompareClass();
Set<Distance> distanceSet
for (Pointpoint : dataList) {
distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point,
x)));
}
// 四、選取最近的k個點
double k = 5;
/**
* 五、計算k個點所在分類出現的頻率
*/
// 1、計算每個分類所包含的點的個數
List<Distance> distanceList= new ArrayList<Distance>(distanceSet);
Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);
// 2、計算頻率
Map<String, Double> p = computeP(map, k);
x.setType(maxP(p));
System.out.println("未知點的型別為:"+x.getType());
}
// 歐式距離計算
public static double oudistance(Point point1, Pointpoint2) {
double temp = Math.pow(point1.getX() - point2.getX(), 2)
+ Math.pow(point1.getY() - point2.getY(), 2);
return Math.sqrt(temp);
}
// 找出最大頻率
public static String maxP(Map<String,Double> map) {
String key = null;
double value = 0.0;
for (Map.Entry<String, Double> entry : map.entrySet()) {
if (entry.getValue() > value) {
key = entry.getKey();
value = entry.getValue();
}
}
return key;
}
// 計算頻率
public static Map<String,Double> computeP(Map<String, Integer> map,
double k) {
Map<String, Double> p = new HashMap<String, Double>();
for (Map.Entry<String, Integer> entry : map.entrySet()) {
p.put(entry.getKey(), entry.getValue() / k);
}
return p;
}
// 計算每個分類包含的點的個數
public static Map<String,Integer> getNumberOfType(
List<Distance> listDistance, List<Point> listPoint, double k) {
Map<String, Integer> map = new HashMap<String, Integer>();
int i = 0;
System.out.println("選取的k個點,由近及遠依次為:");
for (Distance distance : listDistance) {
System.out.println("id為" + distance.getId() + ",距離為:"
+ distance.getDisatance());
long id = distance.getId();
// 通過id找到所屬型別,並存儲到HashMap中
for (Point point : listPoint) {
if (point.getId() == id) {
if (map.get(point.getType()) != null)
map.put(point.getType(), map.get(point.getType()) + 1);
else {
map.put(point.getType(), 1);
}
}
}
i++;
if (i >= k)
break;
}
return map;
}
public static ArrayList<Point> creatDataSet(){
Point point1 = new Point(1, 1.0, 1.1, "A");
Point point2 = new Point(2, 1.0, 1.0, "A");
Point point3 = new Point(3, 1.0, 1.2, "A");
Point point4 = new Point(4, 0, 0, "B");
Point point5 = new Point(5, 0, 0.1, "B");
Point point6 = new Point(6, 0, 0.2, "B");
ArrayList<Point>dataList = new ArrayList<Point>();
dataList.add(point1);
dataList.add(point2);
dataList.add(point3);
dataList.add(point4);
dataList.add(point5);
dataList.add(point6);
return dataList;
}
}
KNN類中涉及到的Point類、Distance類,以及比較器CompareClass類如下:
Point類
/**
 * A sample in the x/y plane, identified by an id and optionally carrying a
 * class label.
 */
public class Point {
    private long id;
    private double x;
    private double y;
    /** Class label; stays {@code null} until one is assigned. */
    private String type;

    /**
     * Creates an unlabelled point.
     *
     * @param id identifier of the point
     * @param x  x coordinate
     * @param y  y coordinate
     */
    public Point(long id, double x, double y) {
        this(id, x, y, null);
    }

    /**
     * Creates a labelled point.
     *
     * @param id   identifier of the point
     * @param x    x coordinate
     * @param y    y coordinate
     * @param type class label of the point
     */
    public Point(long id, double x, double y, String type) {
        this.id = id;
        this.x = x;
        this.y = y;
        this.type = type;
    }

    /** @return identifier of the point */
    public long getId() {
        return id;
    }

    /** @param id new identifier of the point */
    public void setId(long id) {
        this.id = id;
    }

    /** @return x coordinate */
    public double getX() {
        return x;
    }

    /** @param x new x coordinate */
    public void setX(double x) {
        this.x = x;
    }

    /** @return y coordinate */
    public double getY() {
        return y;
    }

    /** @param y new y coordinate */
    public void setY(double y) {
        this.y = y;
    }

    /** @return class label, or {@code null} when none has been assigned */
    public String getType() {
        return type;
    }

    /** @param type new class label */
    public void setType(String type) {
        this.type = type;
    }
}
Distance類
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
/**
 * The distance between one known (labelled) point and one unknown point,
 * with both points referenced by id.
 *
 * <p>The misspelt name {@code disatance} is kept on purpose: callers use the
 * {@code getDisatance()} accessor.
 */
public class Distance {
    /** Id of the known point. */
    private long id;
    /** Id of the unknown point. */
    private long nid;
    /** Distance between the two points. */
    private double disatance;

    /**
     * @param id        id of the known point
     * @param nid       id of the unknown point
     * @param disatance distance between the two points
     */
    public Distance(long id, long nid, double disatance) {
        this.id = id;
        this.nid = nid;
        this.disatance = disatance;
    }

    /** @return id of the known point */
    public long getId() {
        return id;
    }

    /** @param id new id of the known point */
    public void setId(long id) {
        this.id = id;
    }

    /** @return id of the unknown point */
    public long getNid() {
        return nid;
    }

    /** @param nid new id of the unknown point */
    public void setNid(long nid) {
        this.nid = nid;
    }

    /** @return distance between the two points */
    public double getDisatance() {
        return disatance;
    }

    /** @param disatance new distance between the two points */
    public void setDisatance(double disatance) {
        this.disatance = disatance;
    }
}
比較器CompareClass類
1 2 3 4 5 6 7 8 9 |
import java.util.Comparator; //比較器類 public class CompareClass implements Comparator<Distance>{ public int compare(Distance d1, Distance d2) { return d1.getDisatance()>d2.getDisatance()?20 : -1; } } |
若需要把Map<String,Double>(例如typeAndDistance)按照distance,也就是按照map的value進行排序,思路可以參考如下方法:
1. public class Testing {
2.
3. public static void main(String[] args) {
4.
5. HashMap<String,Double> map = new HashMap<String,Double>();
6. ValueComparator bvc = new ValueComparator(map);
7. TreeMap<String,Double> sorted_map = new TreeMap<String,Double>(bvc);
8.
9. map.put("A",99.5);
10. map.put("B",67.4);
11. map.put("C",67.4);
12. map.put("D",67.3);
13.
14. System.out.println("unsorted map: "+map);
15.
16. sorted_map.putAll(map);
17.
18. System.out.println("results: "+sorted_map);
19. }
20. }
21.
/**
 * Orders keys by the value they map to in a backing map, largest value first.
 *
 * <p>Keys with equal values are ordered by the keys' natural order, and keys
 * that are absent from the backing map sort after all present keys. The
 * comparator returns 0 only for equal keys, so a TreeMap built on it never
 * merges distinct keys and lookups such as {@code get} and {@code containsKey}
 * still find their entry.
 */
class ValueComparator implements Comparator<String> {

    /** Backing map consulted on every comparison; it is not copied. */
    Map<String, Double> base;

    /**
     * @param base map that supplies the value for each key being compared
     */
    public ValueComparator(Map<String, Double> base) {
        this.base = base;
    }

    /**
     * @param a first key
     * @param b second key
     * @return a negative value when {@code a} sorts first, a positive value
     *         when {@code b} sorts first, and 0 only when the keys are equal
     */
    @Override
    public int compare(String a, String b) {
        Double left = base.get(a);
        Double right = base.get(b);
        if (left != null && right != null) {
            // Arguments are swapped on purpose: larger values come first.
            int byValue = Double.compare(right.doubleValue(), left.doubleValue());
            if (byValue != 0) {
                return byValue;
            }
        } else if (left != null) {
            return -1;
        } else if (right != null) {
            return 1;
        }
        // Equal values (or both keys missing): fall back to the keys themselves
        // so that only identical keys compare as equal.
        return a.compareTo(b);
    }
}