知识点回顾
1.图像数据的格式:灰度和彩色数据
2.模型的定义
3.显存占用的4种地方
a.模型参数+梯度参数
b.优化器参数
c.数据批量所占显存
d.神经元输出中间状态
4.batchisize和训练的关系
作业:今日代码较少,理解内容即可
# 打印一张彩色图像,用cifar-10数据集
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子确保结果可复现
torch.manual_seed(42)
# 定义数据预处理步骤
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量并归一化到[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化处理
])
# 加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
# 创建数据加载器
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True
)
# CIFAR-10的10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 随机选择一张图片
sample_idx = torch.randint(0, len(trainset), size=(1,)).item()
image, label = trainset[sample_idx]
# 打印图片形状
print(f"图像形状: {image.shape}") # 输出: torch.Size([3, 32, 32])
print(f"图像类别: {classes[label]}")
# 定义图像显示函数(适用于CIFAR-10彩色图像)
def imshow(img):
img = img / 2 + 0.5 # 反标准化处理,将图像范围从[-1,1]转回[0,1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 调整维度顺序:(通道,高,宽) → (高,宽,通道)
plt.axis('off') # 关闭坐标轴显示
plt.show()
# 显示图像
imshow(image)