
Implementing the KNN Classification Algorithm with Hadoop/MapReduce and Spark

KNN
Suppose we have a set of points whose classifications are already known:
//S.txt
100;c1;1.0,1.0
101;c1;1.1,1.2
102;c1;1.2,1.0
103;c1;1.6,1.5
104;c1;1.3,1.7
105;c1;2.0,2.1
106;c1;2.0,2.2
107;c1;2.3,2.3
208;c2;9.0,9.0
209;c2;9.1,9.2
210;c2;9.2,9.0
211;c2;10.6,10.5
212;c2;10.3,10.7
213;c2;9.6,9.1
214;c2;9.4,10.4
215;c2;10.3,10.3
300;c3;10.0,1.0
301;c3;10.1,1.2
302;c3;10.2,1.0
303;c3;10.6,1.5
304;c3;10.3,1.7
305;c3;1.0,2.1
306;c3;10.0,2.2
307;c3;10.3,2.3

and a set of points whose classifications are unknown:
//R.txt
1000;3.0,3.0
1001;10.1,3.2
1003;2.7,2.7
1004;5.0,5.0
1005;13.1,2.2
1006;12.7,12.7

How do we find a suitable classification for each point in R?

The KNN (K-nearest-neighbors) algorithm:

1) Choose K (the choice of K depends on the data and the requirements of the project).
2) Compute the distance between a new input, e.g. [1000;3.0,3.0], and every training record (as with K, the choice of distance function depends on the type of data; for numeric vectors like these, the Euclidean distance sqrt(sum_i (r_i - s_i)^2) is the usual choice).
3) Sort the distances and take the K smallest to determine the K nearest neighbors.
4) Collect the classes these neighbors belong to.
5) Assign the class by majority vote.
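Before distributing anything, it helps to see the whole algorithm at a glance. Below is a minimal single-machine sketch of steps 1-5 in plain Scala, using a hand-picked subset of S.txt; the names KnnSketch and classify are illustrative only and are not part of the MapReduce or Spark code that follows:

import scala.math.{pow, sqrt}

object KnnSketch {
  // a small subset of the training set S.txt: (id, class, vector)
  val S = Seq(
    ("100", "c1", Array(1.0, 1.0)),
    ("101", "c1", Array(1.1, 1.2)),
    ("107", "c1", Array(2.3, 2.3)),
    ("208", "c2", Array(9.0, 9.0)),
    ("300", "c3", Array(10.0, 1.0)))

  // step 2: Euclidean distance between two points
  def distance(r: Array[Double], s: Array[Double]): Double =
    sqrt(r.zip(s).map { case (ri, si) => pow(ri - si, 2) }.sum)

  def classify(r: Array[Double], k: Int): String =
    S.map { case (_, cls, s) => (distance(r, s), cls) } // distances to all of S
      .sortBy(_._1)                                     // step 3: sort ascending...
      .take(k)                                          // ...and keep the K nearest
      .groupBy(_._2)                                    // step 4: collect their classes
      .maxBy(_._2.size)._1                              // step 5: majority vote

  def main(args: Array[String]): Unit =
    println(classify(Array(3.0, 3.0), k = 4)) // point 1000 -> c1, matching the output below
}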
In plain terms: there is a group of rich guys (rich1, rich2, rich3, rich4, ...) and a group of broke guys (broke1, broke2, broke3, broke4, ...). A new person shows up; how do we decide whether he is rich or broke?

First compute the distance (over savings, property, and so on) between this person and every rich guy and every broke guy, then sort these distances in ascending order:

d1 < d2 < d3 < ...

Then tally whom each distance was measured against, for example:

d1 = distance(person, rich32)   -> one vote for "rich"
d2 = distance(person, broke100) -> one vote for "broke"
d3 = distance(person, rich1)    -> one vote for "rich"
...

Finally, whichever class collects the most votes wins; if "rich" has the most votes, the person is classified as rich.

MapReduce implementation idea:

Before map runs (the setup phase), cache the already-classified file S.txt in memory.

Map phase: read one line of R at a time, iterate over every line of S, compute the distance between the two corresponding points, and produce a (distance, classification) pair such as (3, rich), meaning this R record is at distance 3 from a rich guy.

Map output: (rID, (distance, classification))

Reduce input: (rID, [(distance, classification), (distance, classification), (distance, classification), ...])

Reduce: the list [(distance, classification), ...] arrives unordered, so it would have to be sorted inside the reducer. Because this list can be too large to fit in memory, we need a different approach: make the list arrive at the reducer already sorted.

Key technique: composite keys and secondary sort. The map output key is changed from the natural key rID to the composite key (rID, distance). A custom partitioner and a custom grouping comparator make the map output partition and group by rID, while the shuffle sort orders it by distance. A small illustration follows.
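The effect of the composite key can be seen on a toy example. A minimal sketch (plain Scala, with made-up numbers): ordering (rID, distance) pairs lexicographically keeps each rID's records together and sorts them by ascending distance, which is exactly the order the reducer wants to receive:

object SecondarySortSketch {
  def main(args: Array[String]): Unit = {
    // hypothetical shuffled map output: composite keys (rID, distance)
    val mapOutput = Seq(("1000", 2.83), ("1001", 2.0), ("1000", 0.99), ("1001", 0.92))
    // the tuple ordering compares rID first, then distance
    mapOutput.sorted.foreach(println)
    // prints: (1000,0.99) (1000,2.83) (1001,0.92) (1001,2.0)
  }
}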

package cjknn;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfFloatString;
import edu.umd.cloud9.io.pair.PairOfStringFloat;

public class CJKNN extends Configured implements Tool  {
    public static final int DIMENS = 2; // number of dimensions per point
    public static final int K = 6;      // number of nearest neighbours that vote
    public static ArrayList<Float> getVectorFromStr(String str)
    {
        String[] vectorStr = str.split(",");
        ArrayList<Float> vector = new ArrayList<Float>();
        for(int i=0;i<vectorStr.length && i < DIMENS; i++)
        {
            vector.add(Float.valueOf(vectorStr[i]));
        }
        return vector;
    }
    
    public static class CJKNNMapper extends Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString>
    {
        PairOfStringFloat outputKey = new PairOfStringFloat();
        PairOfFloatString outputValue = new PairOfFloatString();
        List<String> S = new ArrayList<String>();
        @Override
        protected void setup(
                Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString>.Context context)
                throws IOException, InterruptedException {
            // read S from the distributed cache; "S" is the symlink name
            // given to S.txt in run() via addCacheFile
            FileReader fr = new FileReader("S");
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while((line = br.readLine()) != null)
            {
                S.add(line);
            }
            br.close(); // closing the BufferedReader also closes the underlying FileReader
        }
        @Override
        protected void map(
                LongWritable key,
                Text value,
                Mapper<LongWritable, Text, PairOfStringFloat, PairOfFloatString>.Context context)
                throws IOException, InterruptedException {
            String[] rTokens = value.toString().split(";");
            String rID = rTokens[0];
            ArrayList<Float> rVector = getVectorFromStr(rTokens[1]);
            for(String s : S)
            {
                String[] sTokens = s.split(";");
                ArrayList<Float> sVector = getVectorFromStr(sTokens[2]);
                float distance = calculateDistance(rVector, sVector);
                // composite key (rID, distance): partitioned and grouped by rID,
                // sorted by distance during the shuffle
                outputKey.set(rID, distance);
                outputValue.set(distance, sTokens[1]);
                context.write(outputKey, outputValue);
            }
        }
        
        private float calculateDistance(ArrayList<Float> rVector,
                ArrayList<Float> sVector) {
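            // Euclidean distance: sqrt of the sum of squared coordinate differences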
            double sum = 0.0;
            for(int i=0;i<rVector.size() && i < DIMENS;i++)
            {
                sum += Math.pow((rVector.get(i) - sVector.get(i)), 2);
            }
            return (float) Math.sqrt(sum);
        }
    }
    
    public static class CJGroupingComparator extends WritableComparator {
        public CJGroupingComparator()
        {
            super(PairOfStringFloat.class, true);
        }

        @SuppressWarnings("rawtypes")
        @Override
        public int compare(WritableComparable wc1, WritableComparable wc2) {
            // group by the natural key (rID) only: returning 0 tells the framework
            // that two composite keys belong to the same reduce group
            PairOfStringFloat pair = (PairOfStringFloat) wc1;
            PairOfStringFloat pair2 = (PairOfStringFloat) wc2;
            return pair.getLeftElement().compareTo(pair2.getLeftElement());
        }
    }
    
    /***
     * Custom partitioner.
     * The partitioner decides, based on the mapper's output key, which reducer receives
     * which mapper output. Two plug-in classes are needed: a custom partitioner that
     * controls which reducer handles which keys, and a custom grouping comparator that
     * groups the values seen by the reducer.
     * The partitioner ensures that all records with the same natural key (rID, not the
     * composite key that also contains the distance) are sent to the same reducer.
     * The grouping comparator then guarantees that, once the data reaches the reducer,
     * it is grouped by the natural key.
     * @author chenjie
     *
     */
    public static class CJPartitioner
       extends Partitioner<PairOfStringFloat, PairOfFloatString> {

        @Override
        public int getPartition(PairOfStringFloat pair,
                                PairOfFloatString value,
                                int numberOfPartitions) {
            // partition on the natural key only, and keep the result non-negative
            // (hashCode() may be negative, and Math.abs(Integer.MIN_VALUE) is still negative)
            return (pair.getLeftElement().hashCode() & Integer.MAX_VALUE) % numberOfPartitions;
        }
    }
    
    public static class CJKNNReducer extends Reducer<PairOfStringFloat, PairOfFloatString, Text, Text>
    {
        @Override
        protected void reduce(
                PairOfStringFloat key,
                Iterable<PairOfFloatString> values,
                Context context)
                throws IOException, InterruptedException {
            System.out.println("key= " + key);
            System.out.println("values:");
            Map<String,Integer> map = new HashMap<String,Integer>();
            int count = 0;
            Iterator<PairOfFloatString> iterator = values.iterator();
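            // the composite key and secondary sort deliver values in ascending
            // order of distance, so the first K values are the K nearest neighbours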
            while(iterator.hasNext())
            {
                PairOfFloatString value = iterator.next();
                System.out.println(value); 
                String sClassificationID = value.getRightElement();
                Integer times = map.get(sClassificationID);
                if (times== null )
                {
                    map.put(sClassificationID, 1);
                }
                else
                {
                    map.put(sClassificationID, times+1);
                }
                count ++;
                if(count >= K)
                    break;
            }
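           // majority vote: the class with the most votes among the K nearest wins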
           int max = 0;
           String maxSClassificationID = "";
           System.out.println("map:");
           for(Map.Entry<String, Integer> entry : map.entrySet())
           {
               System.out.println(entry);
               if(entry.getValue() > max)
               {
                   max = entry.getValue();
                   maxSClassificationID = entry.getKey();
               }
           }
           context.write(new Text(key.getLeftElement()), new Text(maxSClassificationID));
        }
    }


    public static void main(String[] args) throws Exception
    {
        // hard-coded local paths used for testing
        args = new String[2];
        args[0] = "/media/chenjie/0009418200012FF3/ubuntu/R.txt";
        args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJKNN";
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }
    
    public static int submitJob(String[] args) throws Exception {
        int jobStatus = ToolRunner.run(new CJKNN(), args);
        return jobStatus;
    }

    @SuppressWarnings("deprecation")
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        Job job = new Job(conf);
        job.setJobName("KNN");

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        
        job.setMapOutputKeyClass(PairOfStringFloat.class);
        job.setMapOutputValueClass(PairOfFloatString.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
       
        
        job.setMapperClass(CJKNNMapper.class);
        job.setReducerClass(CJKNNReducer.class);

        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        // S.txt is a plain file, not an archive; the "#S" fragment creates the symlink "S"
        job.addCacheFile(new URI("/media/chenjie/0009418200012FF3/ubuntu/S.txt" + "#S"));
        job.setPartitionerClass(CJPartitioner.class);
        job.setGroupingComparatorClass(CJGroupingComparator.class);
        
        FileSystem fs = FileSystem.get(conf);
        Path outPath = new Path(args[1]);
        if(fs.exists(outPath))
        {
            fs.delete(outPath, true);
        }
        
        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
    
    
}


Spark solution:

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
object KNN {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("kNN").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val k = 4 // number of nearest neighbours
    val d = 2 // vector dimensionality
    val inputDatasetR = "file:///media/chenjie/0009418200012FF3/ubuntu/R.txt"
    val inputDatasetS = "file:///media/chenjie/0009418200012FF3/ubuntu/S.txt"
    val output = "file:///media/chenjie/0009418200012FF3/ubuntu/KNN"
    // broadcast the small read-only parameters to all executors
    val broadcastK = sc.broadcast(k)
    val broadcastD = sc.broadcast(d)

    val R = sc.textFile(inputDatasetR)
    /*
    1000;3.0,3.0
    1001;10.1,3.2
    1003;2.7,2.7
    1004;5.0,5.0
    1005;13.1,2.2
    1006;12.7,12.7
     */
    val S = sc.textFile(inputDatasetS)
    /*
    100;c1;1.0,1.0
    101;c1;1.1,1.2
    102;c1;1.2,1.0
    103;c1;1.6,1.5
    104;c1;1.3,1.7
    105;c1;2.0,2.1
    106;c1;2.0,2.2
    107;c1;2.3,2.3
    208;c2;9.0,9.0
    209;c2;9.1,9.2
    210;c2;9.2,9.0
    211;c2;10.6,10.5
    212;c2;10.3,10.7
    213;c2;9.6,9.1
    214;c2;9.4,10.4
    215;c2;10.3,10.3
    300;c3;10.0,1.0
    301;c3;10.1,1.2
    302;c3;10.2,1.0
    303;c3;10.6,1.5
    304;c3;10.3,1.7
    305;c3;1.0,2.1
    306;c3;10.0,2.2
    307;c3;10.3,2.3
     */
    /**
      * Compute the distance between two points.
      *
      * @param rAsString as r1,r2, ..., rd
      * @param sAsString as s1,s2, ..., sd
      * @param d dimensionality
      */
    def calculateDistance(rAsString: String, sAsString: String, d: Int): Double = {
      val r = rAsString.split(",").map(_.toDouble)
      val s = sAsString.split(",").map(_.toDouble)
      // Euclidean distance; NaN marks records whose dimensionality does not match d
      if (r.length != d || s.length != d) Double.NaN else {
        math.sqrt((r, s).zipped.take(d).map { case (ri, si) => math.pow((ri - si), 2) }.reduce(_ + _))
      }
    }

    // Cartesian product of R and S: |R| * |S| pairs, one per (query, training) combination
    val cart = R.cartesian(S)
    /*
    (1000;3.0,3.0,100;c1;1.0,1.0)
    (1000;3.0,3.0,101;c1;1.1,1.2)
    (1000;3.0,3.0,102;c1;1.2,1.0)
    (1000;3.0,3.0,103;c1;1.6,1.5)
    ...
     */
    val knnMapped = cart.map(cartRecord => { // (1000;3.0,3.0,100;c1;1.0,1.0)
      val rRecord = cartRecord._1 // 1000;3.0,3.0
      val sRecord = cartRecord._2 // 100;c1;1.0,1.0
      val rTokens = rRecord.split(";") // (1000, 3.0,3.0)
      val rRecordID = rTokens(0) // 1000
      val r = rTokens(1) // 3.0,3.0
      val sTokens = sRecord.split(";") // (100, c1, 1.0,1.0)
      val sClassificationID = sTokens(1) // c1
      val s = sTokens(2) // 1.0,1.0
      val distance = calculateDistance(r, s, broadcastD.value) // sqrt((3-1)^2+(3-1)^2)=2.8284
      (rRecordID, (distance, sClassificationID)) // (1000,(2.8284,c1))
    })
    // note that groupByKey() provides an expensive solution
    // [you must have enough memory/RAM to hold all values for
    // a given key -- otherwise you might get OOM error], but
    // combineByKey() and reduceByKey() will give a better
    // scale-out performance
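    // A more scalable alternative (a sketch, not used below): keep only the K
    // nearest per key with aggregateByKey, so no key ever materializes its full
    // list of |S| distances:
    //   val knnTopK = knnMapped.aggregateByKey(List.empty[(Double, String)])(
    //     (acc, v) => (v :: acc).sortBy(_._1).take(broadcastK.value),
    //     (a, b)   => (a ++ b).sortBy(_._1).take(broadcastK.value))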
    val knnGrouped = knnMapped.groupByKey()
    /*
    (1005,CompactBuffer((12.159358535712318,c1), (12.041594578792296,c1), (11.960351165413163,c1), (11.52128465059344,c1), (11.81058846967415,c1), (11.10045044131093,c1), (11.1,c1), (10.800462953040487,c1), (7.940403012442126,c2), (8.06225774829855,c2), (7.8390050389064045,c2), (8.668333173107735,c2), (8.94930164873215,c2), (7.736924453553879,c2), (8.996110270555825,c2), (8.570297544426332,c2), (3.3241540277189316,c3), (3.1622776601683795,c3), (3.1384709652950433,c3), (2.596150997149434,c3), (2.8442925306655775,c3), (12.100413216084812,c3), (3.0999999999999996,c3), (2.801785145224379,c3)))
    (1001,CompactBuffer((9.362157870918434,c1), (9.219544457292887,c1), (9.167878707749137,c1), (8.668333173107735,c1), (8.926925562588723,c1), (8.174350127074323,c1), (8.161494961096281,c1), (7.85175139698144,c1), (5.903388857258177,c2), (6.0827625302982185,c2), (5.869412236331676,c2), (7.3171032519706865,c2), (7.502666192761076,c2), (5.921148537234984,c2), (7.23394774656273,c2), (7.102816342831906,c2), (2.202271554554524,c3), (2.0,c3), (2.202271554554524,c3), (1.7720045146669352,c3), (1.513274595042156,c3), (9.166242414424788,c3), (1.004987562112089,c3), (0.9219544457292893,c3)))
    (1000,CompactBuffer((2.8284271247461903,c1), (2.6172504656604803,c1), (2.6907248094147422,c1), (2.0518284528683193,c1), (2.1400934559032696,c1), (1.345362404707371,c1), (1.2806248474865696,c1), (0.9899494936611668,c1), (8.48528137423857,c2), (8.697700845625812,c2), (8.627861844049196,c2), (10.67754653466797,c2), (10.61037228376083,c2), (8.987213138676527,c2), (9.783659846908007,c2), (10.323759005323595,c2), (7.280109889280518,c3), (7.3246160308919945,c3), (7.472616676907761,c3), (7.746612162745725,c3), (7.414849964766652,c3), (2.1931712199461306,c3), (7.045565981523415,c3), (7.333484846919642,c3)))
    (1004,CompactBuffer((5.656854249492381,c1), (5.445181356024793,c1), (5.517245689653488,c1), (4.879549159502341,c1), (4.957822102496216,c1), (4.172529209005013,c1), (4.1036569057366385,c1), (3.818376618407357,c1), (5.656854249492381,c2), (5.869412236331675,c2), (5.8,c2), (7.849203781276162,c2), (7.783315488916019,c2), (6.161980201201558,c2), (6.9656299069072,c2), (7.495331880577405,c2), (6.4031242374328485,c3), (6.360031446463138,c3), (6.56048778674269,c3), (6.603786792439623,c3), (6.243396511515186,c3), (4.94064773081425,c3), (5.730619512757761,c3), (5.948108943185221,c3)))
    (1006,CompactBuffer((16.54629867976521,c1), (16.33431969810803,c1), (16.405486887014355,c1), (15.76863976378432,c1), (15.84171707865028,c1), (15.06154042586614,c1), (14.991330828181999,c1), (14.707821048680186,c1), (5.23259018078045,c2), (5.020956084253276,c2), (5.093132631298737,c2), (3.0413812651491092,c2), (3.124099870362661,c2), (4.75078940808788,c2), (4.0224370722237515,c2), (3.3941125496954263,c2), (12.0074976577137,c3), (11.79025020938911,c3), (11.964113005150026,c3), (11.395174417269795,c3), (11.258774356030056,c3), (15.787653403846944,c3), (10.84158659975559,c3), (10.673331251301065,c3)))
    (1003,CompactBuffer((2.4041630560342617,c1), (2.193171219946131,c1), (2.267156809750927,c1), (1.6278820596099708,c1), (1.7204650534085255,c1), (0.9219544457292889,c1), (0.8602325267042628,c1), (0.5656854249492386,c1), (8.909545442950499,c2), (9.121951545584968,c2), (9.052071586106685,c2), (11.101801655587257,c2), (11.034491379306976,c2), (9.411163583744573,c2), (10.206860437960342,c2), (10.748023074035522,c2), (7.495331880577404,c3), (7.5504966724050675,c3), (7.690253571892151,c3), (7.99061950038919,c3), (7.66550715869472,c3), (1.8027756377319948,c3), (7.3171032519706865,c3), (7.610519036176179,c3)))
     */
    val knnOutput = knnGrouped.mapValues(itr => {
      //itr.toList.sortBy(_._1).foreach(println)
      /*
      (2.596150997149434,c3)
      (2.801785145224379,c3)
      (2.8442925306655775,c3)
      (3.0999999999999996,c3)
      (3.1384709652950433,c3)
      (3.1622776601683795,c3)
      (3.3241540277189316,c3)
      (7.736924453553879,c2)
      (7.8390050389064045,c2)
      (7.940403012442126,c2)
      (8.06225774829855,c2)
      (8.570297544426332,c2)
      (8.668333173107735,c2)
      (8.94930164873215,c2)
      (8.996110270555825,c2)
      (10.800462953040487,c1)
      */
      val nearestK = itr.toList.sortBy(_._1).take(broadcastK.value)
      /*
      (2.596150997149434,c3)
      (2.801785145224379,c3)
      (2.8442925306655775,c3)
      (3.0999999999999996,c3)
       */
      //nearestK.map(f => (f._2, 1)).foreach(println)
      /*
      (c3,1)
      (c3,1)
      (c3,1)
      (c3,1)
       */
      //nearestK.map(f => (f._2, 1)).groupBy(_._1)
      //(c3,List((c3,1), (c3,1), (c3,1), (c3,1)))
      val majority = nearestK.map(f => (f._2, 1)).groupBy(_._1).mapValues(list => {
        val (stringList, intlist) = list.unzip
        intlist.sum
      })
      //(c3,4)
      majority.maxBy(_._2)._1
      //c3
    })
    //(1005,c3)
    knnOutput.foreach(println)
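    // note: saveAsTextFile fails if the output directory already exists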
    knnOutput.saveAsTextFile(output)
    sc.stop()
  }
}

Output:

(1000,c1)
(1001,c3)
(1003,c1)
(1004,c1)
(1005,c3)
(1006,c2)