目录
一、理论基础
KNN算法的原理已在上一篇博客详细论述,在这里不再展开。在之前的部分,我们实现了基于二维坐标点的分类,这是图像分类中的一个简单示例。现在,我们将进一步扩展这个概念,应用于更复杂的图像数据集,以实现实际的图像分类任务。
1. 图像分类
图像可以根据多种标准和特性进行分类。不同的分类方法反映了图像的不同属性和用途。常见的图像类型有彩色图像、灰度图像和二值图像。
2. 实现思路
图像分类的任务就是预测给定图像属于各类标签的的可能性,根据这些输出,模型将选择概率最高的标签作为预测结果。
3. 图像预处理
在使用算法前,一般都需要先对图像进行预处理,包括:归一化、灰度变换、滤波变换等等。
二、数据集准备
MNIST数据集包含来自大约250个不同人手写的数字,其中一半是美国人口调查局的员工,另一半是美国高中生。这些数字经过归一化和中心对齐处理。
1. 数据集结构
每张图像都是固定的28x28像素大小的灰度图像。并且,每张图像都有一个与之对应的标签,标签是0到9的数字,表示图像中手写数字的实际值。
训练集:包含60,000个样本,这些样本用于训练模型。
测试集:包含10,000个样本,这些样本用于测试模型的性能。
2. 加载数据集
在PyTorch中,可以直接使用torchvision.datasets.MNIST
来加载MNIST手写数字数据集。在这里,设置batch_size = 100,即每次迭代中使用的样本数量为100。需要注意的是:在加载数据前,需要将图像转化为张量,确保数据与PyTorch库的兼容性,使得可以方便地在图像上应用各种算法和操作。
MNIST_dataset_loader.py
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import torchvision.transforms as transforms
#指定每次训练迭代的样本数量
batch_size = 100
transform = transforms.ToTensor() #将图片转化为PyTorch张量
train_dataset = dsets.MNIST(root= '指定路径/pymnist',
train= True,
transform=transforms.ToTensor(),
download=True)
test_dataset = dsets.MNIST(root= '指定路径/pymnist',
train= False,
transform=transforms.ToTensor(),
download=True)
#加载数据
train_loader = torch.utils.data.DataLoader(dataset= train_dataset,
batch_size=batch_size,
shuffle= True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle= True)
通过打印训练和测试数据集的数据和标签的尺寸,输出数据集及其对应标签的维度信息,这对于数据结构的把握非常重要。
print("train data:",train_dataset.train_data.size())
print("train labels:",train_dataset.train_labels.size())
print("test data:",test_dataset.test_data.size())
print("test labels:",test_dataset.test_labels.size())
输出结果为:
train data: torch.Size([60000, 28, 28])
train labels: torch.Size([60000])
test data: torch.Size([10000, 28, 28])
test labels: torch.Size([10000])
3. 图像可视化
从第100个训练图像(因为索引从0开始),并且打印出该图像对应的标签:
MNIST_show.py
import matplotlib.pyplot as plt
import MNIST_dataset_loader
train = MNIST_dataset_loader.train_loader.dataset.train_data[99]
plt.imshow(train, cmap=plt.cm.binary)
plt.show()
print(MNIST_dataset_loader.train_loader.dataset.train_labels[99])
输出结果:
tensor(1)
![]() |
图1 MNIST数据集中第100个训练图像 |