1.导入所需的库
import torch
from torch import nn
import torchvision.datasets
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
2.定义神经网络模型:可以设计一个由卷积层和全连接层组成的简单神经网络模型。例如,以下示例代码定义了一个包含两个卷积层和三个全连接层的网络。
class CK(nn.Module):
def __init__(self):
super(CK, self).__init__()
self.model=nn.Sequential(
nn.Conv2d(3,32,5,1,2),
nn.MaxPool2d(2),
nn.Conv2d(32,32,5,1,2),
nn.MaxPool2d(2),
nn.Conv2d(32,64,5,1,2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64*4*4,64),
nn.Linear(64,10)
)
def forward(self,x):
x=self.model(x)
return x
ck=CK()
3.加载数据集:可以使用PyTorch提供的datasets.CIFAR10
类加载CIFAR10数据集,同时需要对图像进行预处理。
tr