pytorch实现cifar10分类

该博客演示了如何使用PyTorch库训练ResNet18模型来处理CIFAR10图像分类任务。首先,定义了超参数如迭代次数、批量大小和学习率。接着,加载并预处理CIFAR10数据集,通过DataLoader进行分批处理。然后,建立ResNet模型,选择交叉熵损失函数和Adam优化器。在GPU上进行模型训练,并保存训练好的模型。最后,对模型进行测试,输出测试图像的准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader


# 超参数定义
EPOCH = 1000
BATCH_SIZE = 64
LR = 0.001

# 数据加载
train_data = datasets.CIFAR10(root='/root/cifar10', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.CIFAR10(root='/root/cifar10', train=False, transform=transforms.ToTensor(), download=True)

# 输出图像
"""temp = train_data[1][0].numpy()
print(temp.shape)
temp = temp.transpose(1, 2, 0)
print(temp.shape)
plt.imshow(temp)
plt.show()"""

# 使用DataLoader进行分批
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

# 使用ResNet Model
model = torchvision.models.resnet18(pretrained=False)

# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=LR)

# device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)


import time
# 训练
for epoch in range(EPOCH):
    start_time = time.time()
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # 前向传播
        outputs = model(inputs)
        # 计算损失函数
        loss = criterion(outputs, labels)
        # 清空上一轮梯度
        optimizer.zero_grad()
        # 反向传播
        loss.backward()
        # 参数更新
        optimizer.step()
    print('epoch{:} loss:{:.4f}'.format(epoch+1, loss.item(), time.time() - start_time))
        
        
# 保存训练模型
file_name = 'cifar10_resnet.pt'
torch.save(model, file_name)
print('model saved')

# 测试
model = torch.load(file_name)
model.eval()
correct, total = 0, 0

for data in test_loader:
    images, labels = data
    images, labels = images.to(device), labels.to(device)
    # 前向传播
    out = model(images)
    _, predicted = torch.max(out.data, 1)
    total = total + labels.size(0)
    correct += (predicted == labels).sum().item()
    
print('10000张测试图像 准确率{:.4}%'.format(100.0 * correct / total))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值