sklearn之分类算法与手写数字识别
sklearn是Python的一个机器学习的库,它有比较完整的监督学习与非监督学习的模型。本文将使用sklearn库里的分类模型来对手写数字(MNIST)做分类实践。
- 数据获取
- 数据读取与存储形式
- sklearn分类模型
- 代码实现与结果
数据获取
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
数据读取与存储形式
将下载好的数据解压带代码目录下即可。
数据文件时二进制格式的,所以要按字节读取。代码如下:
import struct,os
from array import array as pyarray
from numpy import append, array, int8, uint8, zeros
def load_mnist(image_file, label_file, path="."):
digits=np.arange(10)
fname_image = os.path.join(path, image_file)
fname_label = os.path.join(path, label_file)
flbl = open(fname_label, 'rb')
magic_nr, size = struct.unpack(">II", flbl.read(8))
lbl = pyarray("b", flbl.read())
flbl.close()
fimg = open(fname_image, 'rb')
magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
img = pyarray("B", fimg.read())
fimg.close()
ind = [ k for k in range(size) if lbl[k] in digits ]
N = len(ind)
images = zeros((N,