spark 實現K-means演算法
阿新 • • 發佈:2018-11-21
spark 實現K-means演算法
package kmeans;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
/**
 * K-means clustering of 2-D points with Spark.
 *
 * <p>The initial centres are read from a local text file (one {@code x,y} pair per line);
 * the points are read through Spark from a text file in the same {@code x,y} format.
 * Each iteration assigns every point to its nearest centre, computes the mean of every
 * cluster, and stops once the centres no longer move (or an iteration cap is reached).
 *
 * <p>All per-iteration state is aggregated with Spark transformations and brought back to
 * the driver with {@code collectAsMap()}. Closures never write to static fields, because
 * such writes are not propagated back to the driver when tasks run in other JVMs and are
 * racy when several tasks run in the same JVM.
 */
public class kmeans{
    /** Number of clusters (k). */
    private static final int CLUSTER_COUNT = 4;
    /** Number of coordinates per point. */
    private static final int DIMENSIONS = 2;
    /** Safety cap so that a run that never settles exactly still terminates. */
    private static final int MAX_ITERATIONS = 1000;
    /** Total squared movement of all centres at or below which the run counts as converged. */
    private static final double CONVERGENCE_EPSILON = 1e-12;
    /** Local file holding the initial centres. */
    private static final String CENTERS_PATH = "/usr/local/hadoop-2.7.3/centers.txt";
    /** Input points, read through Spark. */
    private static final String DATA_PATH = "spark/input4/k-means.dat";

    /** Current centres, one {x, y} row per cluster. Only touched on the driver. */
    static double[][] center = new double[CLUSTER_COUNT][DIMENSIONS];
    /** Number of points assigned to each centre in the latest iteration. */
    static int[] number = new int[CLUSTER_COUNT];
    /** Centres computed by the latest iteration. */
    static double[][] new_center = new double[CLUSTER_COUNT][DIMENSIONS];

    /**
     * Runs the clustering and prints the intermediate and final centres to standard output.
     *
     * @param args unused
     * @throws IllegalStateException if the centres file cannot be read or is malformed
     */
    public static void main(String[] args) {
        // Load the initial centres from the local file into the center array.
        double[][] initial = loadCenters(CENTERS_PATH);
        for (int i = 0; i < CLUSTER_COUNT; i++) {
            System.arraycopy(initial[i], 0, center[i], 0, DIMENSIONS);
        }

        SparkConf conf = new SparkConf().setAppName("kmeans").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // The same input is scanned on every iteration, so drop blank lines once and cache it.
            JavaRDD<String> datas = jsc.textFile(DATA_PATH).filter(new Function<String, Boolean>() {
                private static final long serialVersionUID = 1L;
                @Override
                public Boolean call(String line) throws Exception {
                    return !line.trim().isEmpty();
                }
            }).cache();

            boolean converged = false;
            for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
                // key: cluster index, value: {sum of x, sum of y, number of points}
                Map<Integer, double[]> sums = sumByNearestCenter(datas, center);

                double distance = 0;
                for (int i = 0; i < CLUSTER_COUNT; i++) {
                    double[] sum = sums.get(Integer.valueOf(i));
                    if (sum == null) {
                        // No point chose this centre: keep it where it is instead of
                        // letting it collapse to (0, 0).
                        number[i] = 0;
                        new_center[i][0] = center[i][0];
                        new_center[i][1] = center[i][1];
                    } else {
                        number[i] = (int) sum[2];
                        new_center[i][0] = sum[0] / sum[2];
                        new_center[i][1] = sum[1] / sum[2];
                        System.out.println("the new center: "+ i+" "+new_center[i][0]+" , "+new_center[i][1]);
                    }
                    double dx = center[i][0] - new_center[i][0];
                    double dy = center[i][1] - new_center[i][1];
                    distance += dx * dx + dy * dy;
                }

                for (int i = 0; i < CLUSTER_COUNT; i++) {
                    center[i][0] = new_center[i][0];
                    center[i][1] = new_center[i][1];
                }

                // Compare against a small tolerance rather than exactly 0.0: the floating-point
                // sums may be combined in a different order from one run to the next.
                if (distance <= CONVERGENCE_EPSILON) {
                    converged = true;
                    break;
                }
                for (int i = 0; i < CLUSTER_COUNT; i++) {
                    System.out.println("the new center: "+" "+center[i][0]+" , "+center[i][1]);
                }
            }

            if (!converged) {
                System.err.println("k-means did not converge within " + MAX_ITERATIONS + " iterations");
            }
            for (int j = 0; j < CLUSTER_COUNT; j++) {
                System.out.println("the final center: "+" "+center[j][0]+" , "+center[j][1]);
            }
        } finally {
            jsc.stop();
        }
    }

    /**
     * Reads the initial centres from a local text file, one {@code x,y} pair per line.
     * Blank lines are ignored; only the first {@code CLUSTER_COUNT} non-blank lines are used.
     *
     * @param path path of the centres file
     * @return a {@code CLUSTER_COUNT x DIMENSIONS} array of centre coordinates
     * @throws IllegalStateException if the file cannot be read, has too few lines, or a line
     *         has too few comma-separated fields
     * @throws NumberFormatException if a coordinate is not a valid number
     */
    private static double[][] loadCenters(String path) {
        List<String> lines = new ArrayList<String>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    lines.add(line);
                }
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot read initial centers from " + path, e);
        }
        if (lines.size() < CLUSTER_COUNT) {
            throw new IllegalStateException("Expected " + CLUSTER_COUNT + " centers in " + path
                    + " but found " + lines.size());
        }

        double[][] centers = new double[CLUSTER_COUNT][DIMENSIONS];
        for (int i = 0; i < CLUSTER_COUNT; i++) {
            String[] fields = lines.get(i).split(",");
            if (fields.length < DIMENSIONS) {
                throw new IllegalStateException("Malformed center on line " + (i + 1) + " of " + path
                        + ": " + lines.get(i));
            }
            for (int j = 0; j < DIMENSIONS; j++) {
                centers[i][j] = Double.parseDouble(fields[j]);
            }
        }
        return centers;
    }

    /**
     * Assigns every point to its nearest centre and sums the points of each cluster.
     *
     * @param points  input lines in {@code x,y} format
     * @param centers current centres; captured by the closure so that it travels with the
     *                task instead of being read from a static field on the executor
     * @return for every non-empty cluster index, an array {sum of x, sum of y, point count}
     */
    private static Map<Integer, double[]> sumByNearestCenter(JavaRDD<String> points, final double[][] centers) {
        JavaPairRDD<Integer, double[]> assigned = points.mapToPair(new PairFunction<String, Integer, double[]>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Tuple2<Integer, double[]> call(String line) throws Exception {
                String[] fields = line.split(",");
                double x = Double.parseDouble(fields[0]);
                double y = Double.parseDouble(fields[1]);
                // The trailing 1.0 is this point's contribution to the cluster's point count.
                return new Tuple2<Integer, double[]>(nearestCenter(x, y, centers), new double[] {x, y, 1.0});
            }
        });
        return assigned.reduceByKey(new Function2<double[], double[], double[]>() {
            private static final long serialVersionUID = 1L;
            @Override
            public double[] call(double[] a, double[] b) throws Exception {
                return new double[] {a[0] + b[0], a[1] + b[1], a[2] + b[2]};
            }
        }).collectAsMap();
    }

    /**
     * Returns the index of the centre with the smallest squared Euclidean distance to (x, y).
     * Ties go to the lowest index.
     */
    private static int nearestCenter(double x, double y, double[][] centers) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.length; i++) {
            double dx = x - centers[i][0];
            double dy = y - centers[i][1];
            double d = dx * dx + dy * dy;
            if (d < bestDistance) {
                bestDistance = d;
                best = i;
            }
        }
        return best;
    }
}
輸入:
1. centers.txt :
96,826
606,776
474,866
400,768
2. data.dat(程式中讀取的路徑為 spark/input4/k-means.dat):
存放所有點的座標,每行一個點,格式為「x,y」。