原文地址:http://blog.csdn.net/fighting_one_piece/article/details/39030969
KNN算法又称为k近邻分类(k-nearest neighbor classification)算法,是最简单、最平凡的分类器。KNN算法从训练集中找到和新数据最接近的k条记录,然后根据它们的主要分类来决定新数据的类别。该算法涉及3个主要因素:训练集、距离、k的大小。可用于客户流失预测、欺诈侦测等(更适合于稀有事件的分类问题)。
计算步骤如下:
1)计算距离:给定测试对象,计算它与训练集中的每个对象的距离。距离计算可以选择欧氏距离、曼哈顿距离、余弦距离等。计算距离之前最好对数据进行规范化处理,以便更好地计算距离。
2)寻找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻。k值可以通过若干次试验来确定,选取分类误差最小的k值。
3)判断分类:根据这k个近邻归属的主要类别,来对测试对象分类。判定方式主要是投票决定,少数服从多数,近邻中哪个类别的点最多就分为该类。也可以通过加权投票方法来决定。
优点
简单,易于理解,易于实现,无需估计参数,无需训练
适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型)
特别适合于多分类问题(multi-modal,对象具有多个类别标签),例如根据基因特征来判断其功能分类,KNN比SVM的表现要好
缺点
可解释性较差,无法给出决策树那样的规则。
下面是基于MapReduce简单实现KNN的代码:
- public class KNNClassifier {
- private static void configureJob(Job job) {
- job.setJarByClass(KNNClassifier.class);
- job.setMapperClass(KNNMapper.class);
- job.setMapOutputKeyClass(Text.class);
- job.setMapOutputValueClass(PointWritable.class);
- job.setReducerClass(KNNReducer.class);
- job.setOutputKeyClass(Text.class);
- job.setOutputValueClass(Text.class);
- job.setInputFormatClass(TextInputFormat.class);
- job.setOutputFormatClass(TextOutputFormat.class);
- }
- public static void main(String[] args) {
- long start = System.currentTimeMillis();
- Configuration configuration = new Configuration();
- try {
- String[] inputArgs = new GenericOptionsParser(
- configuration, args).getRemainingArgs();
- if (inputArgs.length != 4) {
- System.out.println("error, please input three path.");
- System.out.println("1 train set path.");
- System.out.println("2 test set path.");
- System.out.println("3 output path.");
- System.out.println("4 k value.");
- System.exit(2);
- }
- DistributedCache.addCacheFile(new Path(inputArgs[0]).toUri(), configuration);
- configuration.set("k", inputArgs[3]);
- Job job = new Job(configuration, "KNN Classifier");
- FileInputFormat.setInputPaths(job, new Path(inputArgs[1]));
- FileOutputFormat.setOutputPath(job, new Path(inputArgs[2]));
- configureJob(job);
- System.out.println(job.waitForCompletion(true) ? 0 : 1);
- long end = System.currentTimeMillis();
- System.out.println("spend time: " + (end - start) / 1000);
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- }
- class KNNMapper extends Mapper<LongWritable, Text, Text, PointWritable> {
- private List<Point> trainPoints = new ArrayList<Point>();
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Configuration conf = context.getConfiguration();
- FileSystem fs = FileSystem.get(conf);
- URI[] uris = DistributedCache.getCacheFiles(conf);
- Path[] paths = HDFSUtils.getPathFiles(fs, new Path(uris[0]));
- for(Path path : paths) {
- FSDataInputStream in = fs.open(path);
- BufferedReader reader = new BufferedReader(new InputStreamReader(in));
- String line = reader.readLine();
- while (null != line && !"".equals(line)) {
- String[] datas = line.split(" ");
- trainPoints.add(new Point(Double.parseDouble(datas[0]),
- Double.parseDouble(datas[1]), datas[2]));
- line = reader.readLine();
- }
- IOUtils.closeQuietly(in);
- IOUtils.closeQuietly(reader);
- }
- }
- @Override
- protected void map(LongWritable key, Text value, Context context)
- throws IOException, InterruptedException {
- String line = value.toString();
- String[] datas = line.split(" ");
- double x = Double.parseDouble(datas[0]);
- double y = Double.parseDouble(datas[1]);
- Point testPoint = new Point(x, y);
- String outputKey = x + "-" + y;
- for (Point trainPoint : trainPoints) {
- double distance = distance(testPoint, trainPoint);
- context.write(new Text(outputKey), new PointWritable(trainPoint, distance));
- }
- }
- public double distance(Point point1, Point point2) {
- return Math.sqrt(Math.pow((point1.getX() - point2.getX()), 2)
- + Math.pow((point1.getY() - point2.getY()), 2));
- }
- @Override
- protected void cleanup(Context context) throws IOException, InterruptedException {
- super.cleanup(context);
- }
- }
- class KNNReducer extends Reducer<Text, PointWritable, Text, Text> {
- private int k = 0;
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Configuration conf = context.getConfiguration();
- k = Integer.parseInt(conf.get("k", "0"));
- }
- @Override
- protected void reduce(Text key, Iterable<PointWritable> values,
- Context context) throws IOException, InterruptedException {
- System.out.println(key);
- List<PointWritable> points = new ArrayList<PointWritable>();
- for (PointWritable point : values) {
- points.add(new PointWritable(point));
- }
- Collections.sort(points, new Comparator<PointWritable>() {
- @Override
- public int compare(PointWritable o1, PointWritable o2) {
- return o1.getDistance().compareTo(o2.getDistance());
- }
- });
- Map<String, Integer> map = new HashMap<String, Integer>();
- k = points.size() < k ? points.size() : k;
- for (int i = 0; i < k; i++) {
- PointWritable point = points.get(i);
- String category = point.getCategory().toString();
- Integer count = map.get(category);
- map.put(category, null == count ? 1 : count + 1);
- }
- List<Map.Entry<String, Integer>> list =
- new ArrayList<Map.Entry<String, Integer>>(map.entrySet());
- Collections.sort(list, new Comparator<Map.Entry<String, Integer>>(){
- @Override
- public int compare(Entry<String, Integer> o1,
- Entry<String, Integer> o2) {
- return o2.getValue().compareTo(o1.getValue());
- }
- });
- context.write(key, new Text(list.get(0).getKey()));
- }
- @Override
- protected void cleanup(Context context) throws IOException,
- InterruptedException {
- super.cleanup(context);
- }
- }