# -*- coding: utf-8 -*-
"""
加载并显示MNIST数据集
PyTorch Programming & Deep Learning
@author: Mike Yuan, Copyright 2020~2021
"""
import matplotlib.pyplot as plt
from torchvision import datasets # 导入PyTorch的视觉数据集模块
rows, cols = 10, 10 # 定义图像显示的网格布局为10行10列
def main():
""" 主函数 """
# 加载MNIST数据集
# root: 数据集存储路径
# train=True: 加载训练集,False为测试集
# download=True: 如果数据集不存在则自动下载
# transform=None: 不进行数据转换(保持原始PIL图像格式)
train_set = datasets.MNIST(
root='../datasets/mnist',
train=True,
download=True,
transform=None
)
test_set = datasets.MNIST(
root='../datasets/mnist',
train=False,
download=True,
transform=None
)
# 获取类别名称(直接通过数据集属性classes获取)
class_names = train_set.classes # 返回类别名称列表,如['0', '1', ..., '9']
# 打印数据集基本信息
print("训练集shape: ", train_set.data.shape) # 输出 [60000, 28, 28]
print("训练集样本数:", len(train_set.targets)) # 输出 60000
print("测试集shape: ", test_set.data.shape) # 输出 [10000, 28, 28]
print("测试集样本数:", len(test_set.targets)) # 输出 10000
# 绘制单个样本图像(索引为idx的样本)
idx = 0 # 可通过修改此值查看不同样本
plt.figure()
plt.imshow(train_set.data[idx], cmap="gray_r") # 显示反色灰度图像
plt.grid(False) # 隐藏网格线
plt.show()
# 绘制10x10网格的样本图像
plt.figure(figsize=(10, 11)) # 设置画布大小(宽度10英寸,高度11英寸)
for i in range(rows * cols):
plt.subplot(rows, cols, i + 1) # 创建子图(编号从1开始)
plt.xticks([]) # 隐藏x轴刻度
plt.yticks([]) # 隐藏y轴刻度
plt.grid(False) # 隐藏网格线
plt.imshow(train_set.data[i], cmap='gray_r') # 显示图像
plt.xlabel(class_names[train_set.targets[i]]) # 显示标签作为子图标题
plt.show()
if __name__ == '__main__':
main()
代码说明
-
数据加载
datasets.MNIST
: 直接加载MNIST数据集。download=True
: 自动下载数据集到root
指定路径(需联网)。transform=None
: 保持原始数据格式(uint8类型,形状[N, 28, 28])。
-
类别名称获取
train_set.classes
: 直接返回类别名称列表(‘0’到’9’)。
-
数据信息打印
train_set.data
: 原始图像数据(未应用transform
)。train_set.targets
: 标签数据。
-
单样本可视化
plt.imshow()
显示原始像素数据,cmap="gray_r"
确保图像为反色灰度。
-
多样本网格可视化
plt.subplot()
创建子图网格,plt.xlabel()
显示数字标签。
注意事项
-
数据集下载问题
- 确保
root='../datasets/mnist'
路径有效且有写入权限。 - 首次运行需保持
download=True
,后续可改为False
以加速加载。
- 确保
-
数据预处理缺失
transform=None
意味着未进行归一化(原始像素值为0~255)。- 建议: 添加
transform=transforms.ToTensor()
以自动归一化到[0,1],例如:transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 进一步归一化到[-1,1] ])
-
数据类型不一致
train_set.data
是uint8类型(0~255),若需用于模型训练,必须转换为float并归一化。
-
索引越界风险
- 确保
rows * cols
不超过数据集大小(例如训练集有60000样本,10x10=100是安全的)。
- 确保
-
标签显示优化
- 若标签显示为方框,需检查系统是否支持默认字体,或指定中文字体(如
plt.xlabel(..., fontproperties='SimHei')
)。
- 若标签显示为方框,需检查系统是否支持默认字体,或指定中文字体(如
-
图像显示细节
cmap='gray_r'
中的_r
表示反色(黑色背景),若需正常显示可改为cmap='gray'
。plt.figure(figsize=(10,11))
的尺寸可根据屏幕分辨率调整,避免图像重叠。
-
代码冗余
test_set
在本代码中未实际使用,可根据需求添加测试集的可视化逻辑。
-
性能问题
- 一次性加载全部数据(如
train_set.data
)可能消耗内存,对于超大数据集建议使用DataLoader分批加载。
- 一次性加载全部数据(如
哎