nn.Module
包含许多现成的层,如
▪ Linear
▪ ReLU
▪ Sigmoid
▪ Conv2d
▪ ConvTransposed2d
▪ Dropout
▪ etc
Container
self.net = nn.Sequential(
)
方便管理参数
list(net.parameters())[1] #第0层的bias
save and load
torch.save(net.state_dict(), ‘ckpt.mdl’)
net.load_state_dict(torch.load(‘ckpt.mdl’))
train and test
在Dropout等网络中,train与test表现不同,模块可以方便地进行切换
net.train() #train
net.eval() #test
自定义类
将2维打平操作
class MyLinear(nn.Module):
def __init__(self, inp, outp):
super(MyLinear, self).__init__()
# requires_grad = True
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
def forward(self, x):
x = x @ self.w.t() + self.b
return x
自定义类时,需要将参数加入到parameter中
数据增强
▪ Flip 翻转
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
]))
▪ Rotate 旋转
transforms.RandomRotation(15), #最大+/- 15。
transforms.RandomRotation([90, 180, 270]), #共有4种操作
Scale 缩放
transforms.Resize([32, 32]), #参数为调整之后大小,注意其是一个参数,【】
▪ Random Move & Crop 旋转加裁剪
transforms.RandomRotation(15),
transforms.RandomCrop([28, 28]),