利用上一篇文章搭建的卷积神经网络进行模型训练。
将搭建的卷积神经网络放在model.py中。
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
class Gao(nn.Module):
def __init__(self):
super(Gao, self).__init__()
self.model = nn.Sequential(
Conv2d(3, 32, 5, stride=1, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, stride=1, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, stride=1, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
if __name__ == '__main__':
gao=Gao()
input = torch.ones((64, 3, 32, 32))
output = gao(input)
print(output.shape)
下面是训练模型的主要代码:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWrit