学习MNIST数据集处理和可视化。

# -*- coding: utf-8 -*-
"""
加载并显示MNIST数据集

PyTorch Programming & Deep Learning
@author: Mike Yuan, Copyright 2020~2021
"""

import matplotlib.pyplot as plt
from torchvision import datasets  # 导入PyTorch的视觉数据集模块

rows, cols = 10, 10  # 定义图像显示的网格布局为10行10列

def main():
    """ 主函数 """
    # 加载MNIST数据集
    # root: 数据集存储路径
    # train=True: 加载训练集,False为测试集
    # download=True: 如果数据集不存在则自动下载
    # transform=None: 不进行数据转换(保持原始PIL图像格式)
    train_set = datasets.MNIST(
        root='../datasets/mnist', 
        train=True,
        download=True, 
        transform=None
    )
    test_set = datasets.MNIST(
        root='../datasets/mnist', 
        train=False,
        download=True, 
        transform=None
    )

    # 获取类别名称(直接通过数据集属性classes获取)
    class_names = train_set.classes  # 返回类别名称列表,如['0', '1', ..., '9']

    # 打印数据集基本信息
    print("训练集shape: ", train_set.data.shape)      # 输出 [60000, 28, 28]
    print("训练集样本数:", len(train_set.targets))     # 输出 60000
    print("测试集shape: ", test_set.data.shape)       # 输出 [10000, 28, 28]
    print("测试集样本数:", len(test_set.targets))      # 输出 10000

    # 绘制单个样本图像(索引为idx的样本)
    idx = 0  # 可通过修改此值查看不同样本
    plt.figure()
    plt.imshow(train_set.data[idx], cmap="gray_r")  # 显示反色灰度图像
    plt.grid(False)  # 隐藏网格线
    plt.show()

    # 绘制10x10网格的样本图像
    plt.figure(figsize=(10, 11))  # 设置画布大小(宽度10英寸,高度11英寸)
    for i in range(rows * cols):
        plt.subplot(rows, cols, i + 1)  # 创建子图(编号从1开始)
        plt.xticks([])   # 隐藏x轴刻度
        plt.yticks([])   # 隐藏y轴刻度
        plt.grid(False)  # 隐藏网格线
        plt.imshow(train_set.data[i], cmap='gray_r')  # 显示图像
        plt.xlabel(class_names[train_set.targets[i]])  # 显示标签作为子图标题
    plt.show()

if __name__ == '__main__':
    main()

代码说明

  1. 数据加载

    • datasets.MNIST: 直接加载MNIST数据集。
    • download=True: 自动下载数据集到root指定路径(需联网)。
    • transform=None: 保持原始数据格式(uint8类型,形状[N, 28, 28])。
  2. 类别名称获取

    • train_set.classes: 直接返回类别名称列表(‘0’到’9’)。
  3. 数据信息打印

    • train_set.data: 原始图像数据(未应用transform)。
    • train_set.targets: 标签数据。
  4. 单样本可视化

    • plt.imshow()显示原始像素数据,cmap="gray_r"确保图像为反色灰度。
  5. 多样本网格可视化

    • plt.subplot()创建子图网格,plt.xlabel()显示数字标签。

注意事项

  1. 数据集下载问题

    • 确保root='../datasets/mnist'路径有效且有写入权限。
    • 首次运行需保持download=True,后续可改为False以加速加载。
  2. 数据预处理缺失

    • transform=None意味着未进行归一化(原始像素值为0~255)。
    • 建议: 添加transform=transforms.ToTensor()以自动归一化到[0,1],例如:
      transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.5,), (0.5,))  # 进一步归一化到[-1,1]
      ])
      
  3. 数据类型不一致

    • train_set.data是uint8类型(0~255),若需用于模型训练,必须转换为float并归一化。
  4. 索引越界风险

    • 确保rows * cols不超过数据集大小(例如训练集有60000样本,10x10=100是安全的)。
  5. 标签显示优化

    • 若标签显示为方框,需检查系统是否支持默认字体,或指定中文字体(如plt.xlabel(..., fontproperties='SimHei'))。
  6. 图像显示细节

    • cmap='gray_r'中的_r表示反色(黑色背景),若需正常显示可改为cmap='gray'
    • plt.figure(figsize=(10,11))的尺寸可根据屏幕分辨率调整,避免图像重叠。
  7. 代码冗余

    • test_set在本代码中未实际使用,可根据需求添加测试集的可视化逻辑。
  8. 性能问题

    • 一次性加载全部数据(如train_set.data)可能消耗内存,对于超大数据集建议使用DataLoader分批加载。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值