导入必要的包
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
通过transform 实现对数据进行处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
数据加载:
trainset = torchvision.datasets.CIFAR10(root=r'./data',
train=True,download=False,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size= 4, shuffle=False, num_workers=0)
testset = torchvision.datasets.CIFAR10(root=r'./data',
train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird','cat','deer',
'dog'

本文详细介绍了使用PyTorch深度学习框架对CIFAR-10数据集进行图像分类的过程,包括数据预处理、模型定义、训练、评估及预测。展示了如何构建卷积神经网络,设置损失函数和优化器,以及利用GPU加速训练。
最低0.47元/天 解锁文章
9059

被折叠的 条评论
为什么被折叠?



