文章目录
MNIST
数据介绍:本章使用MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。每张图像都用其代表的数字标记。这个数据集被广为使用,因此也被称作是机器学习领域的“Hello World”:但凡有人想到了一个新的分类算法,都会想看看在MNIST上的执行结果。因此只要是学习机器学习的人,早晚都要面对MNIST。
# 使用sklearn的函数来获取MNIST数据集
from sklearn.datasets import fetch_openml
import numpy as np
import os
import time
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# 耗时巨大
def sort_by_target(mnist):
reorder_train=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[:60000])]))[:,1]
reorder_test=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[60000:])]))[:,1]
mnist.data[:60000]=mnist.data[reorder_train]
mnist.target[:60000]=mnist.target[reorder_train]
mnist.data[60000:]=mnist.data[reorder_test+60000]
mnist.target[60000:]=mnist.target[reorder_test+60000]
下面这一部分有点耗时,需要一点耐心,如果运行时间过长建议重新运行,之前有一次运行了15分钟才出来,没想明白为什么
a=time.time()
mnist=fetch_openml('mnist_784',version=1,cache=True)
mnist.target=mnist.target.astype(np.int8)
sort_by_target(mnist)
b=time.time()
b-a
运行时间:
41.92971682548523
mnist["data"], mnist["target"]
(array([[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
…,
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.],
[0., 0., 0., …, 0., 0., 0.]]),
array([0, 0, 0, …, 9, 9, 9], dtype=int8))
mnist.data.shape
(70000, 784)
X,y=mnist["data"],mnist["target"]
X.shape
(70000, 784)
y.shape
(70000,)
28*28
784
# 展示图片
def plot_digit(data):
image = data.reshape(28, 28)
plt.imshow(image, cmap = mpl.cm.binary,
interpolation="nearest")
plt.axis("off")
some_digit = X[36000]
plot_digit(X[36000].reshape(28,28))

y[36000]
5
# 更好看的图片展示
def plot_digits(instances,images_per_row=10,**options):
size=28
# 每一行有一个
image_pre_row=min(len(instances),images_per_row)
images=[instances.reshape(size,size) for instances in instances]
# 有几行
n_rows=(len(instances)-1) // image_pre_row+1
row_images=[]
n_empty=n_rows*image_pre_row-len(instances)
images.append(np.zeros((size,size*n_empty)))
for row in range(n_rows):
# 每一次添加一行
rimages=images[row*image_pre_row:(row+1)*image_pre_row]
# 对添加的每一行的额图片左右连接
row_images.append(np.concatenate(rimages,axis=1))
# 对添加的每一列图片 上下连接
image=np.concatenate(row_images,axis=0)
plt.imshow(image,cmap=mpl.cm.binary,**options)
plt.axis("off")
plt.figure(figsize=(9,9))
example_images=np.r_[X[:12000:600],X[13000:30600:600],X[30600:60000:590]]
plot_digits(example_images,images_per_row=10)
plt.show()

接下来,我们需要创建一个测试集,并把其放在一边。
X_train, X_test, y_train, y_test = X[:60000], X[60000:],</

本文探讨了使用MNIST数据集训练和评估机器学习分类器的过程,包括性能度量如精度、召回率和F1分数,以及如何通过混淆矩阵和ROC曲线分析分类器的优缺点。
最低0.47元/天 解锁文章
2493

被折叠的 条评论
为什么被折叠?



