程式人生 > 資料探勘筆記-分類-KNN-原理與簡單實現

資料探勘筆記-分類-KNN-原理與簡單實現

/**
 * Driver for a MapReduce k-nearest-neighbour classifier.
 *
 * <p>Usage: {@code KNNClassifier <train set path> <test set path> <output path> <k>}.
 * The training set path is registered in the DistributedCache so mappers can load it;
 * the test set is the job input and the predicted category per test point is written
 * to the output path.
 */
public class KNNClassifier {
	
	/** Wires the KNN mapper/reducer and plain-text input/output formats into the job. */
	private static void configureJob(Job job) {
		job.setJarByClass(KNNClassifier.class);
		
		job.setMapperClass(KNNMapper.class);
		job.setMapOutputKeyClass(Text.class);
		job.setMapOutputValueClass(PointWritable.class);
		
		job.setReducerClass(KNNReducer.class);
		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(Text.class);
		
		job.setInputFormatClass(TextInputFormat.class);
		job.setOutputFormatClass(TextOutputFormat.class);
	}
	
	/**
	 * Parses arguments, submits the job and waits for it to finish.
	 *
	 * <p>Exits with status 2 on bad arguments and 1 when the job fails or an exception
	 * is raised; returns normally on success.
	 *
	 * @param args generic Hadoop options followed by train path, test path, output path, k
	 */
	public static void main(String[] args) {
		long start = System.currentTimeMillis();
		Configuration configuration = new Configuration();
		int exitCode;
		try {
			String[] inputArgs = new GenericOptionsParser(
						configuration, args).getRemainingArgs();
			if (inputArgs.length != 4) {
				System.out.println("error, please input four arguments.");
				System.out.println("1 train set path.");
				System.out.println("2 test set path.");
				System.out.println("3 output path.");
				System.out.println("4 k value.");
				System.exit(2);
			}
			// Validate k before submitting; otherwise a bad value only surfaces inside reducer tasks.
			int k;
			try {
				k = Integer.parseInt(inputArgs[3].trim());
			} catch (NumberFormatException e) {
				k = 0;
			}
			if (k <= 0) {
				System.out.println("error, k value must be a positive integer: " + inputArgs[3]);
				System.exit(2);
			}
			DistributedCache.addCacheFile(new Path(inputArgs[0]).toUri(), configuration);
			
			configuration.set("k", String.valueOf(k));
			Job job = new Job(configuration, "KNN Classifier");
			
			FileInputFormat.setInputPaths(job, new Path(inputArgs[1]));
			FileOutputFormat.setOutputPath(job, new Path(inputArgs[2]));
			
			configureJob(job);
			
			boolean success = job.waitForCompletion(true);
			System.out.println(success ? 0 : 1);
			long end = System.currentTimeMillis();
			System.out.println("spend time: " + (end - start) / 1000);
			exitCode = success ? 0 : 1;
		} catch (Exception e) {
			e.printStackTrace();
			exitCode = 1;
		}
		// Surface failures to the calling shell/scheduler; success still returns normally as before.
		if (exitCode != 0) {
			System.exit(exitCode);
		}
	}

}

/**
 * Map phase of the KNN classifier.
 *
 * <p>{@link #setup} loads the whole training set (from the first DistributedCache URI)
 * into memory; {@link #map} then emits, for each test point, one
 * (test-point key, training point + distance) pair per training point so the reducer
 * can select the nearest ones.
 */
class KNNMapper extends Mapper<LongWritable, Text, Text, PointWritable> {
	
	/** Training points loaded in setup(); each holds x, y and a category label. */
	private final List<Point> trainPoints = new ArrayList<Point>();
	
	/**
	 * Loads every file returned by HDFSUtils.getPathFiles for the first cached URI.
	 * Non-blank lines are parsed as whitespace-separated "x y category"; blank lines
	 * are skipped.
	 *
	 * @throws IOException if no cache file is registered, a file cannot be read, or a
	 *         line has fewer than three fields
	 */
	@Override
	protected void setup(Context context) throws IOException, InterruptedException {
		super.setup(context);
		Configuration conf = context.getConfiguration();
		FileSystem fs = FileSystem.get(conf);
		URI[] uris = DistributedCache.getCacheFiles(conf);
		if (uris == null || uris.length == 0) {
			throw new IOException("no training set registered in the DistributedCache");
		}
		Path[] paths = HDFSUtils.getPathFiles(fs, new Path(uris[0]));
		for (Path path : paths) {
			loadTrainPoints(fs, path);
		}
	}

	/** Parses one training file into trainPoints, closing the streams even on failure. */
	private void loadTrainPoints(FileSystem fs, Path path) throws IOException {
		FSDataInputStream in = fs.open(path);
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
			String line;
			while ((line = reader.readLine()) != null) {
				String trimmed = line.trim();
				// Skip (rather than stop at) blank lines so later data is not silently dropped.
				if (trimmed.isEmpty()) {
					continue;
				}
				String[] datas = trimmed.split("\\s+");
				if (datas.length < 3) {
					throw new IOException("malformed training line in " + path + ": " + line);
				}
				trainPoints.add(new Point(Double.parseDouble(datas[0]),
						Double.parseDouble(datas[1]), datas[2]));
			}
		} finally {
			if (reader != null) {
				IOUtils.closeQuietly(reader);
			}
			IOUtils.closeQuietly(in);
		}
	}

	/**
	 * Emits the distance from the test point on this line to every training point.
	 * The line is parsed as whitespace-separated "x y"; blank lines are ignored.
	 */
	@Override
	protected void map(LongWritable key, Text value, Context context)
			throws IOException, InterruptedException {
		String line = value.toString().trim();
		if (line.isEmpty()) {
			return;
		}
		String[] datas = line.split("\\s+");
		double x = Double.parseDouble(datas[0]);
		double y = Double.parseDouble(datas[1]);
		Point testPoint = new Point(x, y);
		// context.write serializes immediately, so one key object can serve all emits for this point.
		Text outputKey = new Text(x + "-" + y);
		for (Point trainPoint : trainPoints) {
			double distance = distance(testPoint, trainPoint);
			context.write(outputKey, new PointWritable(trainPoint, distance));
		}
	}
	
	/** Returns the Euclidean distance between two points in the (x, y) plane. */
	public double distance(Point point1, Point point2) {
		return Math.sqrt(Math.pow((point1.getX() - point2.getX()), 2)
				+ Math.pow((point1.getY() - point2.getY()), 2));
	}
}

/**
 * Reduce phase of the KNN classifier: for each test point, takes the k nearest
 * training points and writes the majority category.
 */
class KNNReducer extends Reducer<Text, PointWritable, Text, Text> {

	/** Number of voting neighbours, read from the "k" job property in setup(). */
	private int k = 0;
	
	/**
	 * Reads k from the job configuration.
	 *
	 * @throws IllegalArgumentException if "k" is missing or not positive (previously this
	 *         surfaced as an IndexOutOfBoundsException in every reduce call)
	 */
	@Override
	protected void setup(Context context) throws IOException, InterruptedException {
		super.setup(context);
		Configuration conf = context.getConfiguration();
		k = Integer.parseInt(conf.get("k", "0"));
		if (k <= 0) {
			throw new IllegalArgumentException(
					"job property \"k\" must be a positive integer, got " + k);
		}
	}

	/**
	 * Sorts the candidate training points by distance and writes the category that
	 * wins a majority vote among the nearest {@code min(k, candidates)} of them.
	 * Ties go to the category that reached the winning count first, i.e. through
	 * the closer neighbours.
	 */
	@Override
	protected void reduce(Text key, Iterable<PointWritable> values,
			Context context) throws IOException, InterruptedException {
		// Copy each value: Hadoop may reuse the same instance across iterations.
		List<PointWritable> points = new ArrayList<PointWritable>();
		for (PointWritable point : values) {
			points.add(new PointWritable(point));
		}
		Collections.sort(points, new Comparator<PointWritable>() {
			@Override
			public int compare(PointWritable o1, PointWritable o2) {
				return o1.getDistance().compareTo(o2.getDistance());
			}
		});
		// Local bound: overwriting the k field here would shrink k for every later key.
		int neighbours = Math.min(k, points.size());
		Map<String, Integer> votes = new HashMap<String, Integer>();
		String winner = null;
		int winnerVotes = 0;
		for (int i = 0; i < neighbours; i++) {
			String category = points.get(i).getCategory().toString();
			Integer count = votes.get(category);
			int updated = null == count ? 1 : count + 1;
			votes.put(category, updated);
			// Strict '>' keeps the earlier (nearer) category on ties.
			if (updated > winnerVotes) {
				winner = category;
				winnerVotes = updated;
			}
		}
		if (winner != null) {
			context.write(key, new Text(winner));
		}
	}

}