import torch
class MyNet(torch.nn.Module):
def __init__(self):
# 必须调用父类的构造函数,因为想要使用父类的方法,这也是继承Module的目的
super(MyNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu1 = torch.nn.ReLU()
self.max_pooling1 = torch.nn.MaxPool2d(2, 1)
self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu2 = torch.nn.ReLU()
self.max_pooling2 = torch.nn.MaxPool2d(2, 1)
self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
self.dense2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.max_pooling1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.max_pooling2(x)
x = self.dense1