文章目录
-
-
- 1. 模块导入与全局变量
- 2. 主函数 `main()`
- 3. 打印数据集基本信息
- 4. 绘制单个样本图像
- 5. 数据归一化
- 6. 绘制样本网格
- 7. 程序入口
- 关键问题与改进建议
- 总结
1. 模块导入与全局变量
"""
加载并显示CIFAR10数据集
PyTorch Programming & Deep Learning
@author: Mike Yuan, Copyright 2020~2021
"""
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
rows, cols = 10, 10
- 功能:导入必要的库,定义全局常量
rows
和 cols
,用于后续的图像网格展示。
2. 主函数 main()
def main():
""" 主函数 """
cifar10_train = torchvision.datasets.cifar.CIFAR10(
root='../datasets/CIFAR10',
train=True,
download=False,
transform=transforms.ToTensor()
)
train_images, train_labels = cifar10_train.data, cifar10_train