程式人生 > spark 實現K-means演算法

spark 實現K-means演算法

spark 實現K-means演算法

package kmeans;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;.
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.
apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.VoidFunction; import scala.Tuple2; public class kmeans{ static double[][] center = new double[4][2]; //這裡有4箇中心點,為2維 static int[] number = new int[4]; //記錄屬於當前中心點的資料的個數,方便做除法
static double[][] new_center = new double[4][2]; //計算出來的新中心點 public static void main(String[] args) { // 從檔案中讀出中心點,並且放入center陣列中 ArrayList<String> arrayList = new ArrayList<String>(); try { File file = new File("/usr/local/hadoop-2.7.3/centers.txt"); InputStreamReader input = new InputStreamReader(new FileInputStream(file)); BufferedReader bf = new BufferedReader(input); // 按行讀取字串 String str; while ((str = bf.readLine()) != null) { arrayList.add(str); } bf.close(); input.close(); } catch (IOException e) { e.printStackTrace(); } // 對ArrayList中儲存的字串進行處理 for (int i = 0; i < 4; i++) { for (int j = 0; j < 2; j++) { String s = arrayList.get(i).split(",")[j]; center[i][j] = Double.parseDouble(s); } } //System.out.println("center+++" + center[3][1]); SparkConf conf = new SparkConf().setAppName("kmeans").setMaster("local[*]"); JavaSparkContext jsc = new JavaSparkContext(conf); JavaRDD<String> datas = jsc.textFile("spark/input4/k-means.dat"); //從hdfs上讀取data while(true) { for (int i = 0; i< 4;i++) //注意每次迴圈都需要將number[i]變為0 { number[i]=0; } //將data分開,得到key: 屬於某個中心點的序號(0/1/2/3),value: 與該中心點的距離 JavaPairRDD<Integer, Tuple2<Double, Double>> data = datas.mapToPair(new PairFunction<String, Integer,Tuple2<Double, Double>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Integer,Tuple2<Double, Double>> call(String str) throws Exception { final double[][] loc = center; String[] datasplit = str.split(","); double x = Double.parseDouble(datasplit[0]); double y = Double.parseDouble(datasplit[1]); double minDistance = 99999999; int centerIndex = 0; for(int i = 0;i < 4;i++){ double itsDistance = (x-loc[i][0])*(x-loc[i][0])+(y-loc[i][1])*(y-loc[i][1]); if(itsDistance < minDistance){ minDistance = itsDistance; centerIndex = i; } } number[centerIndex]++; //得到屬於4箇中心點的個數 return new Tuple2<Integer,Tuple2<Double, Double>>(centerIndex, new Tuple2<Double,Double>(x,y)); // the center's number & data } }); //得到key: 屬於某個中心點的序號, value:新中心點的座標 JavaPairRDD<Integer, 
Iterable<Tuple2<Double, Double>>> sum_center = data.groupByKey(); //System.out.println(sum_center.collect()); JavaPairRDD<Integer,Tuple2<Double, Double>> Ncenter = sum_center.mapToPair(new PairFunction<Tuple2<Integer, Iterable<Tuple2<Double, Double>>>,Integer,Tuple2<Double, Double>>() { private static final long serialVersionUID = 1L; @Override public Tuple2<Integer, Tuple2<Double, Double>> call(Tuple2<Integer, Iterable<Tuple2<Double, Double>>> a)throws Exception { //System.out.println("i am here**********new center******"); int sum_x = 0; int sum_y = 0; Iterable<Tuple2<Double, Double>> it = a._2; for(Tuple2<Double, Double> i : it) { sum_x += i._1; sum_y +=i._2; } double average_x = sum_x / number[a._1]; double average_y = sum_y/number[a._1]; //System.out.println("**********new center******"+a._1+" "+average_x+","+average_y); return new Tuple2<Integer,Tuple2<Double,Double>>(a._1,new Tuple2<Double,Double>(average_x,average_y)); } }); //將中心點輸出 Ncenter.foreach(new VoidFunction<Tuple2<Integer,Tuple2<Double,Double>>>() { private static final long serialVersionUID = 1L; @Override public void call(Tuple2<Integer,Tuple2<Double,Double>> t) throws Exception { new_center[t._1][0] = t._2()._1; new_center[t._1][1] = t._2()._2; System.out.println("the new center: "+ t._1+" "+t._2()._1+" , "+t._2()._2); } }); //判斷新的中心點和原來的中心點是否一樣,一樣的話退出迴圈得到結果,不一樣的話繼續迴圈(這裡可以設定一個迭代次數) double distance = 0; for(int i=0;i<4;i++) { distance += (center[i][0]-new_center[i][0])*(center[i][0]-new_center[i][0]) + (center[i][1]-new_center[i][1])*(center[i][1]-new_center[i][1]); } if(distance == 0.0) { //finished for(int j = 0;j<4;j++) { System.out.println("the final center: "+" "+center[j][0]+" , "+center[j][1]); } break; } else { for(int i = 0;i<4;i++) { center[i][0] = new_center[i][0]; center[i][1] = new_center[i][1]; new_center[i][0] = 0; new_center[i][1] = 0; System.out.println("the new center: "+" "+center[i][0]+" , "+center[i][1]); } } } } }

輸入:

1. centers.txt（初始中心點，每行一個，格式為 x,y；程式讀取的本地路徑為 /usr/local/hadoop-2.7.3/centers.txt）：
	96,826
	606,776	
	474,866
	400,768
2. data.dat（程式中實際讀取的路徑為 spark/input4/k-means.dat）：
    存放所有點的座標，每行一個點，格式為 x,y。