1.首先是写一个nn
from torch import nn
import torch
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.model=nn.Sequential(
nn.Conv2d(3,32,5,1,2),
nn.MaxPool2d(2),
nn.Conv2d(32,32, 5,1,2),
nn.MaxPool2d(2),
nn.Conv2d(32,64,5,1,2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64*4*4,64),
nn.Linear(64,10)
)
def forward(self, x):
x = self.model(x)
return x
2.第二部加载数据集,定义模型,进行训练,训练过之后会保存一个权重文件,你可以加载这个权重文件进行再训练,迁移学习。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
#from model import Net
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,