Cs231n-2017课程作业(assignment 1)之KNN
KNN算法理解:(百度百科)
KNN也即k-Nearest Neighbor,K(最)近邻分类算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也即是K个邻居,可根据一定的距离计算公式寻找邻居,如曼哈顿距离,欧式距离等), 这K个实例的多数属于某个类,就把该输入实例分类到这个类中。
更形象的理解如下:图中有两个不同的类,分别是蓝色的正方形所属的一类和红色的三角所属的一类,我们的目的是预测出绿色的圆属于他们中的那一类?
- 如果K=3,绿色圆点的最近的3个邻居是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类
- 如果K=5,绿色圆点的最近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类
于此我们看到,当无法判定当前待分类点是从属于已知分类中的哪一类时,我们可以依据统计学的理论看它所处的位置特征,衡量它周围邻居的权重,而把它归为(或分配)到权重更大的那一类。这就是K近邻算法的核心思想。
KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN 算法本身简单有效,它是一种 lazy-learning 算法,分类器不需要使用训练集进行训练,训练时间复杂度为0。KNN 分类的计算复杂度和训练集中的文档数目成正比,也就是说,如果训练集中文档总数为 n,那么 KNN 的分类时间复杂度为O(n)。
参考:
作业代码下载:https://github.com/Burton2000/CS231n-2017
参考理解博客:https://blog.youkuaiyun.com/u014485485/article/details/79433514
环境配置:
Window7 64位 + Jupyter notebook (通过anaconda安装,python3.6)
数据下载:
下载CIFAR-10数据库:http://www.cs.toronto.edu/~kriz/cifar.html,下载CIFAR-10 python version解压到cs231n/datasets目录下。
开始做作业:
点开 knn.ipynb对每一个cell,shift+enter运行。
首先是加载数据