1. 程式人生 > >KNN(最鄰近值演算法) scala實現

KNN(最鄰近值演算法) scala實現

最鄰近值演算法實現

工程目錄結構

這裡寫圖片描述

程式碼

訓練模型

package com.knn.model

/**
  * 訓練資料模型
  *
  * @param aA 資料a
  * @param bA 資料b
  * @param typeA 型別
  */
class KNNModel(aA:Double,bA:Double,typeA:String) {
  var a:Double = aA
  var b:Double = bA
  var resType: String = typeA
  //距離
  var distince:Double = 0
}

核心演算法程式碼

package com.knn.core

import com.knn.model.KNNModel

import scala.collection.immutable.ListMap

/**
  * 最鄰近演算法核心演算法
  */
class KNN_Core {
//  val knnModel = new KNNModel(null,null,null,null,null);

  /**
    * 對訓練資料進行升序排序(根據距離來進行排序)
    * @param knnMOdels
    * @return
    */
  private def sortByDistince(knnMOdels:List[KNNModel]):List[KNNModel] ={
    //進行升序排序
return knnMOdels.sortBy(knn => knn.distince) } /** * 使用歐幾里得度量計算出距離 * @param knnMOdels * @param k */ private def coluaclateDistince(knnMOdels:List[KNNModel],k: KNNModel):Unit = knnMOdels.foreach(n=>{ n.distince = Math.sqrt((k.a-n.a)*(k.a-n.a)+(k.b-n.b)*(k.b-n.b)) }) /** * 獲取距離目標點附近(指定集合大小的範圍內存在最多的資料) * @param
ks * @return */
private def findMostValue(ks:List[KNNModel]):String ={ //找出訓練集中在規定數量中存在最多的類 var resType = "" var typeCountMap:Map[String,Int] = Map() //進行計數 ks.toStream.foreach(k=>{ if (typeCountMap.contains(k.resType)){ typeCountMap+= (k.resType -> (typeCountMap(k.resType)+1)) }else{ typeCountMap+=(k.resType -> 1) } }) //獲取最多數量型別(根據鍵值進行排序) resType = ListMap(typeCountMap.toSeq.sortWith(_._2 >_._2):_*).take(1).keySet.head return resType } def reckonRelize(kns:List[KNNModel],kn:KNNModel,k: Int):String={ //計算距離 coluaclateDistince(kns,kn) //根據距離排序 var knsSort = sortByDistince(kns) //獲取前k個數據 var knss = knsSort.take(k) //獲取k個數據中數量最多的型別 return findMostValue(knss) } }

執行程式碼

package com.knn

import com.knn.core.KNN_Core
import com.knn.model.KNNModel

/**
  * 分割類
  */
object app {
  def main(args: Array[String]): Unit = {
    //資料準備
    var knnModels:List[KNNModel] = List()
    knnModels = knnModels.::(new KNNModel(1.1, 1.1, "A"))
    knnModels = knnModels.::(new KNNModel(1.2, 1.2, "A"))
    knnModels = knnModels.::(new KNNModel(1.1, 1.0, "A"))
    knnModels = knnModels.::(new KNNModel(3.0, 3.1, "B"))
    knnModels = knnModels.::(new KNNModel(3.1, 3.0, "B"))
    knnModels = knnModels.::(new KNNModel(5.4, 6.0, "C"))
    knnModels = knnModels.::(new KNNModel(5.5, 6.3, "C"))
    knnModels = knnModels.::(new KNNModel(6.0, 12.0, "C"))
    knnModels = knnModels.::(new KNNModel(10.0, 12.0, "M"))
    //待預測資料
    var knnModle = new KNNModel(4.0, 3.2, "A")

    var kNN_Core = new KNN_Core
    //演算法實現
    var resType = kNN_Core.reckonRelize(knnModels,knnModle,5)
    println("預測結果",resType)
  }
}

參考資料