前言
这是CS231n 图像分类课时的作业,自己编写KNN算法,实现图像分类,该任务涉及L1,L2距离,以及K值超参数的选择
1. KNN算法
KNN算法思想就是将输入的每一个测试样本与所有的训练样本计算距离值,然后选择K个距离最小的候选者,采用少数服从多数的投票方式选择类别
2. 实战部分
2.1 数据集载入
这里选用的数据集是 cifar-10 数据集 http://www.cs.toronto.edu/~kriz/cifar.html
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
import numpy as np
def load_cifar10_batch(filename):
import pickle
with open(filename,'rb') as f:
data=pickle.load(f,encoding='bytes') ##加载二进值文件
x=data[b'data']
y=data[b'labels']
x=x.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
y=np.array(y)
return x,y
def load_cifar10(filename):
'''加载全部数据,共有6个batch'''
X=[]
Y=[]
for i in range(1,6):
dir_url=filename+'/'+'data_batch_%s'%i
x,y=load_cifar10_batch(dir_url)
X.append(x)
Y.append(y)
X_train=np.concatenate(X)
Y_train=np.concatenate(Y)
X_test,Y_test=load_cifar10_batch(filename+'/'+'test_batch')
return X_train,Y_train,X_test,Y_test
X_train,Y_train,X_test,Y_test=load_cifar10("/Users/yugui/Downloads/cifar-10-batches-py")
print('训练集大小:',X_train.shape)
print('测试集大小:',X_test.shape)
# 图片展示
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置figure_size尺寸
plt.rcParams['image.interpolation'] = 'nearest' # 设置 interpolation style
plt.rcParams['image.cmap'] = 'gray' # 设置 颜色 style
'''enumerate函数说明:
函数原型:enumerate(sequence, [start=0]) #第二个参数为指定索引
功能:将可循环序列sequence以start开始分别列出序列数据和数据下标
即对一个可遍历的数据对象(如列表、元组或字符串),enumerate会将该数据对象组合为一个索引序列,同时列出数据和数据下标'''
classes = ['plane', 'car', 'bird', 'cat', 'd