C#實現K-近鄰(KNN)演算法
阿新 • • 發佈:2018-12-06
KNN(k-nearest-neighbor)演算法的思想是:在輸入新資料時,找到與該資料最接近的k個鄰居,並在這k個鄰居中找出出現次數最多的類別,將新資料歸為該類別。
Iris資料集是常用的分類實驗資料集,由Fisher於1936年收集整理。Iris也稱鳶尾花卉資料集,是一類多重變數分析的資料集。資料集包含150筆資料,分為3類,每類50筆,每筆資料包含4個屬性。可通過花萼長度,花萼寬度,花瓣長度,花瓣寬度4個屬性預測鳶尾花卉屬於(Setosa,Versicolour,Virginica)三個種類中的哪一類。
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApplication2
{
    /// <summary>
    /// One record of the Iris data set: four numeric measurements plus a species label.
    /// </summary>
    /// <remarks>
    /// Auto-properties are used so that every property owns its own backing storage.
    /// The previous hand-written <c>PetalWidth</c> accessor read and wrote the
    /// petal-length field, so assigning a width silently overwrote the length.
    /// </remarks>
    public class Iris
    {
        /// <summary>Sepal length of the flower.</summary>
        public double SepalLength { get; set; }

        /// <summary>Sepal width of the flower.</summary>
        public double SepalWidth { get; set; }

        /// <summary>Petal length of the flower.</summary>
        public double PetalLength { get; set; }

        /// <summary>Petal width of the flower (stored independently of <see cref="PetalLength"/>).</summary>
        public double PetalWidth { get; set; }

        /// <summary>Species label; <c>null</c> until it has been assigned.</summary>
        public string Species { get; set; }
    }
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApplication2
{
    /// <summary>
    /// K-nearest-neighbour classifier for <see cref="Iris"/> records, using the
    /// Euclidean distance over the four numeric measurements.
    /// </summary>
    public class KNN
    {
        /// <summary>
        /// Labelled sample data. <see cref="Classify"/> appends each newly labelled record to this list.
        /// </summary>
        private List<Iris> sampleList;

        /// <summary>
        /// Records whose species is to be determined.
        /// </summary>
        private List<Iris> unclassifyList;

        /// <summary>
        /// Number of nearest neighbours that vote on the species.
        /// </summary>
        private int k;

        /// <summary>
        /// Creates a classifier.
        /// </summary>
        /// <param name="sampleList">Labelled sample data; this list is modified by <see cref="Classify"/>.</param>
        /// <param name="unclassifyList">Records to classify.</param>
        /// <param name="k">Number of neighbours to consult; must be at least 1.</param>
        /// <exception cref="ArgumentNullException">Either list is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is less than 1.</exception>
        public KNN(List<Iris> sampleList, List<Iris> unclassifyList, int k)
        {
            if (sampleList == null)
            {
                throw new ArgumentNullException("sampleList");
            }
            if (unclassifyList == null)
            {
                throw new ArgumentNullException("unclassifyList");
            }
            if (k < 1)
            {
                // With k < 1 no neighbour would ever vote and records would be left unlabelled.
                throw new ArgumentOutOfRangeException("k", k, "k must be at least 1.");
            }

            this.sampleList = sampleList;
            this.unclassifyList = unclassifyList;
            this.k = k;
        }

        /// <summary>
        /// Assigns a species to every unclassified record by majority vote among its
        /// k nearest samples, then appends the record to the sample list.
        /// </summary>
        /// <remarks>
        /// Because each labelled record is appended to the sample list, it takes part in
        /// the classification of the records that follow it.
        /// </remarks>
        public void Classify()
        {
            // Fixed up front so the loop bound does not move if both constructor
            // arguments happen to reference the same list.
            int unclassifyCount = unclassifyList.Count;

            for (int i = 0; i < unclassifyCount; i++)
            {
                Iris target = unclassifyList[i];

                // Distance from the target to every sample currently known
                // (including records labelled in earlier iterations).
                int sampleCount = sampleList.Count;
                Tuple<string, double>[] distances = new Tuple<string, double>[sampleCount];
                for (int j = 0; j < sampleCount; j++)
                {
                    Iris sample = sampleList[j];
                    distances[j] = Tuple.Create(sample.Species, CalculateDistance(sample, target));
                }

                // Count the species of the k closest samples.
                Dictionary<string, int> votes = new Dictionary<string, int>();
                foreach (Tuple<string, double> neighbour in distances.OrderBy(t => t.Item2).Take(k))
                {
                    int count;
                    votes.TryGetValue(neighbour.Item1, out count); // count stays 0 when the key is new
                    votes[neighbour.Item1] = count + 1;
                }

                // With no samples there are no votes and the species is left untouched.
                // NOTE(review): ties are resolved by whichever entry the dictionary
                // enumerates first (OrderByDescending is a stable sort).
                if (votes.Count > 0)
                {
                    target.Species = votes.OrderByDescending(v => v.Value).First().Key;
                }

                sampleList.Add(target);
            }
        }

        /// <summary>
        /// Computes the Euclidean distance between two records over their four measurements.
        /// </summary>
        /// <param name="sample">A sample record.</param>
        /// <param name="unclassify">The record being classified.</param>
        /// <returns>The Euclidean distance between the two records.</returns>
        public double CalculateDistance(Iris sample, Iris unclassify)
        {
            double delta_SepalLength = unclassify.SepalLength - sample.SepalLength;
            double delta_SepalWidth = unclassify.SepalWidth - sample.SepalWidth;
            double delta_PetalLength = unclassify.PetalLength - sample.PetalLength;
            double delta_PetalWidth = unclassify.PetalWidth - sample.PetalWidth;

            return Math.Sqrt(delta_SepalLength * delta_SepalLength
                + delta_SepalWidth * delta_SepalWidth
                + delta_PetalLength * delta_PetalLength
                + delta_PetalWidth * delta_PetalWidth);
        }

        /// <summary>
        /// Writes every record of the sample list to a file, one tab-separated line per
        /// record: index, the four measurements, species.
        /// </summary>
        /// <param name="filePath">Destination file; it is created or overwritten.</param>
        public void Print(string filePath)
        {
            // Numbers are formatted with the invariant culture so the file always uses
            // '.' as the decimal separator, whatever the machine's regional settings.
            System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

            StringBuilder stringBuilder = new StringBuilder();
            for (int i = 0; i < sampleList.Count; i++)
            {
                Iris iris = sampleList[i];
                stringBuilder
                    .Append(i.ToString(invariant)).Append("\t")
                    .Append(iris.SepalLength.ToString(invariant)).Append("\t")
                    .Append(iris.SepalWidth.ToString(invariant)).Append("\t")
                    .Append(iris.PetalLength.ToString(invariant)).Append("\t")
                    .Append(iris.PetalWidth.ToString(invariant)).Append("\t")
                    .AppendLine(iris.Species);
            }

            // 'using' guarantees the writer is flushed and the file handle released
            // even when the write throws.
            using (System.IO.FileStream fs = new System.IO.FileStream(filePath, System.IO.FileMode.Create))
            using (System.IO.StreamWriter sw = new System.IO.StreamWriter(fs))
            {
                sw.Write(stringBuilder.ToString());
            }
        }
    }
}
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApplication2
{
    /// <summary>
    /// Console entry point: loads the sample and unclassified data sets, runs the
    /// KNN classifier with k = 5 and writes the result to a text file.
    /// </summary>
    class Program
    {
        /// <summary>
        /// Runs the classification.
        /// </summary>
        /// <param name="args">
        /// Optional: the first argument overrides the output file path. Without it the
        /// original hard-coded path is used.
        /// </param>
        static void Main(string[] args)
        {
            List<Iris> sampleList = GetIrisDataset(AppDomain.CurrentDomain.BaseDirectory + "樣本.txt");
            List<Iris> unclassifyList = GetIrisDataset(AppDomain.CurrentDomain.BaseDirectory + "未分類.txt");

            string outputPath = (args != null && args.Length > 0)
                ? args[0]
                : @"C:\Users\DSF\Desktop\t.txt";

            KNN tool = new KNN(sampleList, unclassifyList, 5);
            tool.Classify();
            tool.Print(outputPath);
            Console.WriteLine("OK");
        }

        /// <summary>
        /// Reads a space-separated Iris data file into a list of records.
        /// </summary>
        /// <remarks>
        /// Fields 1-4 of each line are parsed as sepal length, sepal width, petal length
        /// and petal width; field 5 is the species. Field 0 is not used (presumably a row
        /// index - confirm against the data files). Reading stops at the first empty line
        /// or at end of file.
        /// </remarks>
        /// <param name="filePath">Path of the file to read.</param>
        /// <returns>The records read, in file order.</returns>
        /// <exception cref="FormatException">
        /// A line has fewer than six fields or a measurement is not a valid number.
        /// </exception>
        static List<Iris> GetIrisDataset(string filePath)
        {
            // The data uses '.' as decimal separator, so parse with the invariant culture
            // instead of the machine's regional settings.
            System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;
            List<Iris> list = new List<Iris>();

            // 'using' releases the file handle even when a line fails to parse.
            using (System.IO.FileStream fs = new System.IO.FileStream(filePath, System.IO.FileMode.Open))
            using (System.IO.StreamReader sr = new System.IO.StreamReader(fs))
            {
                int lineNumber = 0;
                string readLine = sr.ReadLine();
                while (!string.IsNullOrEmpty(readLine))
                {
                    lineNumber++;
                    string[] splitArray = readLine.Split(' ');
                    if (splitArray.Length < 6)
                    {
                        throw new FormatException(string.Format(
                            invariant,
                            "{0}: line {1} has {2} field(s); at least 6 are required.",
                            filePath, lineNumber, splitArray.Length));
                    }

                    Iris iris = new Iris();
                    iris.SepalLength = Convert.ToDouble(splitArray[1], invariant);
                    iris.SepalWidth = Convert.ToDouble(splitArray[2], invariant);
                    iris.PetalLength = Convert.ToDouble(splitArray[3], invariant);
                    iris.PetalWidth = Convert.ToDouble(splitArray[4], invariant);
                    iris.Species = splitArray[5];
                    list.Add(iris);

                    readLine = sr.ReadLine();
                }
            }

            return list;
        }
    }
}