#!/usr/bin/env python3 # -*- coding: utf-8 -*- ''' @author toby KNN算法:确定K值(通过交叉验证),找出前K个数(可以优化),对K个数进行统计,找出统计最多的数。 ''' import math import random print 'KNN算法' k=20 def distance(x1, y1, x2, y2): #计算距离,距离在多维空间或者特殊情况采取计算距离策略可能不同 return math.sqrt(math.pow((x1-x2), 2)+math.pow((y1-y2), 2)) def get_middle(start, end, init_data): target1=start target2=end while start<end: if(init_data[target1][3]>init_data[target2][3]): temp=init_data[target2] init_data[target2]=init_data[target1] init_data[target1]=temp start=start+1 end=target2 target1=target2 else: end=end-1; target2=end return target1 def divide(start, end, init_data): if start>=end: return if end<k: return elif start>k: return else: middle=get_middle(start, end, init_data) divide(start, middle-1, init_data) divide(middle+1, end, init_data) def sort(init_data): #排序算法根据快速排序算法中中间值左边小于该值,如果该值在大于k则舍弃掉该值后面的数据。如果该值小于k则保留k值左边的数据 divide(0, len(init_data)-1, init_data) return ; target_x=40 target_y=40 print "目标坐标为:(%s, %s)" % (target_x, target_y) init_data=[] for i in range(1, 100): x=random.randint(1,100) y=random.randint(1,100) label_prefix="label" init_data.append([label_prefix+str(random.randint(20, 30)), x, y, -1]) for node in init_data: #计算距离 node[3]=distance(node[1], node[2], target_x, target_y) sort(init_data) result_data={} while k>=0: #统计数据 k=k-1 if init_data[k][0] in result_data: result_data[init_data[k][0]]=result_data[init_data[k][0]]+1 else: result_data[init_data[k][0]]=1 max=0 for node in result_data: if(result_data[node]>max): max=result_data[node] target_name=node print target_name+"为target该归属的类"
KNN算法初步改进版
最新推荐文章于 2025-03-14 14:09:11 发布