A KNN Classifier for CIFAR-10

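This article walks through classifying the CIFAR-10 dataset with the nearest neighbor algorithm. CIFAR-10 contains 50,000 training images and 10,000 test images, each 32x32 pixels. The class of a test sample is predicted from the L1 distance between it and every training sample. Although the accuracy is only 24.92%, the exercise illustrates the basic principle of nearest neighbor classification.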


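First, read the CIFAR-10 dataset. The training data holds 50,000 images, each a 32x32 three-channel color image, so the training set consists of 50,000 vectors of length 32x32x3 = 3072. The training images therefore form a (50000, 3072) matrix, accompanied by 50,000 labels. The test set has 10,000 images and 10,000 labels. The training set is split into 5 batches; when loading, the 5 batches are read into a single 50000x3072 training matrix, and the corresponding labels are read into an array of length 50000.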
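To make the 3072-element layout concrete, here is a minimal sketch of turning one row back into a 32x32x3 image. It assumes the layout of the python version of CIFAR-10, where the first 1024 values of a row are the red channel, the next 1024 green, and the last 1024 blue, each stored row by row. The row below is random stand-in data, not a real CIFAR-10 image, and the variable names are only illustrative.

```python
import numpy as np

# Stand-in for one row of the (50000, 3072) training matrix.
# Random values are used here; this is NOT real CIFAR-10 data.
rng = np.random.default_rng(0)
row = rng.integers(0, 256, size=3072, dtype=np.uint8)

# Split into 3 channels of 32x32, then move the channel axis last
# so the result has the usual (height, width, channel) shape.
image = row.reshape(3, 32, 32).transpose(1, 2, 0)

print(image.shape)  # (32, 32, 3)
```

The nearest neighbor classifier itself never needs this reshape, since it compares the flat vectors directly; the reshape is only useful for viewing an image.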

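The nearest neighbor algorithm here simply stores all of the training data. At prediction time it computes the L1 distance (Manhattan distance, the sum of absolute differences) between each test image and every training image, and takes the label of the training image with the smallest distance as the predicted label.
The algorithm is inefficient: every prediction scans the whole training set, so prediction takes a long time.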
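The prediction rule described above can be sketched on a toy example. The three 4-pixel "images" and their labels below are made up for illustration and have nothing to do with the real dataset; the point is only to show the L1 distance and the argmin step on numbers small enough to check by hand.

```python
import numpy as np

# Toy "training set": three 4-pixel images with labels (made-up values).
train_x = np.array([[10, 20, 30, 40],
                    [200, 210, 220, 230],
                    [90, 100, 110, 120]], dtype=np.uint8)
train_y = np.array([0, 1, 2])

# One toy test image.
test_x = np.array([95, 90, 120, 115], dtype=np.uint8)

# Cast to a signed type first, because uint8 subtraction wraps around.
diff = train_x.astype(np.int16) - test_x.astype(np.int16)

# L1 distance: sum of absolute differences along each row.
distances = np.abs(diff).sum(axis=1)

# The label of the closest training image is the prediction.
nearest = int(np.argmin(distances))
predicted_label = int(train_y[nearest])

# The distances are 320, 440 and 30, so the predicted label is 2.
print(distances.tolist(), predicted_label)
```

Each prediction touches every training row, which is why the cost grows with the size of the training set.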


import numpy as np
import pickle

# Location of the training and test batches
file_path = "E:/cifar-10-python/cifar-10-batches-py/"

class NearestNeighbor(object):
  def __init__(self):
    pass

  def train(self, X, y):
    """ X is N x D where each row is an example. Y is 1-dimension of size N """
    # the nearest neighbor classifier simply remembers all the training data
    self.Xtr = X
    self.ytr = y

  def predict(self, X):
    """ X is N x D where each row is an example we wish to predict label for """
    num_test = X.shape[0]
    # lets make sure that the output type matches the input type
    Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

    # CIFAR-10 pixels are stored as uint8, where subtraction wraps around
    # modulo 256; cast to a wider signed type so differences can go negative
    Xtr = self.Xtr.astype(np.int16)

    # loop over all test rows
    for i in range(num_test):
      # find the nearest training image to the i'th test image
      # using the L1 distance (sum of absolute value differences)
      distances = np.sum(np.abs(Xtr - X[i,:].astype(np.int16)), axis = 1)
      min_index = np.argmin(distances) # get the index with smallest distance
      Ypred[i] = self.ytr[min_index] # predict the label of the nearest example

    return Ypred

# Unpack one pickled CIFAR-10 batch file into a dictionary
def unpickle(file):
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='latin1')
    return batch

# Load the dataset
def load_CIFAR10(file):
    # Read the five training batches and stack them into a single
    # 50000 x 3072 data matrix and a label array of length 50000,
    # then read the test batch.
    dictTrain = unpickle(file + "data_batch_1")
    dataTrain = dictTrain['data']
    labelTrain = dictTrain['labels']
    for i in range(2,6):
        dictTrain = unpickle(file+"data_batch_"+str(i))
        dataTrain = np.vstack([dataTrain, dictTrain['data']])
        labelTrain = np.hstack([labelTrain, dictTrain['labels']])

    dictTest = unpickle(file + "test_batch")
    dataTest = dictTest['data']
    labelTest = dictTest['labels']
    labelTest = np.array(labelTest)

    return dataTrain, labelTrain, dataTest, labelTest
dataTrain, labelTrain, dataTest, labelTest = load_CIFAR10(file_path)

print(dataTrain.shape)
print(type(labelTrain))
print(dataTest.shape)
print(len(labelTest))


nn = NearestNeighbor() # create a Nearest Neighbor classifier class
nn.train(dataTrain[:50000, :], labelTrain[:50000]) # train the classifier on the training images and labels
labelTest_Predict = nn.predict(dataTest[:10000, :]) # predict labels on the test images
# and now print the classification accuracy, which is the fraction
# of examples that are correctly predicted (i.e. label matches)
print('accuracy: %f' % np.mean(labelTest_Predict == labelTest[:10000]))

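Figure: screenshot of the program output

The run shown here reached an accuracy of 24.92%, which is quite low, and prediction took an extremely long time.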
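When the L1 distance is computed directly on raw pixel arrays, the integer type matters. The CIFAR-10 python batches store pixels as uint8, and NumPy's uint8 subtraction wraps around modulo 256 instead of going negative, which distorts the distances and may be one reason for a low score if no cast is applied. The following minimal sketch uses two made-up pixel values, not real data, to show the effect and one possible way around it (casting to int16 before subtracting).

```python
import numpy as np

# Two made-up pixel pairs stored the way CIFAR-10 stores pixels (uint8).
a = np.array([10, 200], dtype=np.uint8)
b = np.array([20, 100], dtype=np.uint8)

# uint8 arithmetic wraps modulo 256: 10 - 20 becomes 246 instead of -10,
# and np.abs cannot repair it because the value is already non-negative.
wrapped = np.abs(a - b)

# Casting to a wider signed type first gives the true absolute differences.
correct = np.abs(a.astype(np.int16) - b.astype(np.int16))

print(wrapped.tolist())  # [246, 100]
print(correct.tolist())  # [10, 100]
```

int16 is enough here because a difference of two 8-bit pixels always lies between -255 and 255; a float cast would also work at the cost of more memory.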
