My main goal in writing this post is to get familiar with how to build models in PyTorch.
1. AlexNet
Five convolutional layers followed by three fully connected layers. Without further ado, here is the code:
import torch
from torch import nn
from torchstat import stat


class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()  # b, 3, 224, 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # b, 64, 55, 55
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # b, 64, 27, 27
            nn.Conv2d(64, 192, kernel_size=5, padding=2),           # b, 192, 27, 27
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # b, 192, 13, 13
            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # b, 384, 13, 13
            nn.ReLU(True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # b, 256, 13, 13
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # b, 256, 13, 13
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=3, stride=2))                  # b, 256, 6, 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        print(x.size())  # sanity check: should be b, 256, 6, 6
        x = x.view(x.size(0), 256 * 6 * 6)  # flatten before the fully connected layers
        x = self.classifier(x)
        return x


model = AlexNet(10)
stat(model, (3, 224, 224))
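For reference, the spatial sizes in the comments follow the standard formula out = floor((in + 2*padding - kernel_size) / stride) + 1. For the first convolution this gives floor((224 + 2*2 - 11) / 4) + 1 = 55, and the max pooling after it gives floor((55 - 3) / 2) + 1 = 27, matching the b, 64, 55, 55 and b, 64, 27, 27 annotations.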
The stat function from torchstat is used to estimate the model's parameter count and computational cost, which also doubles as a check that the model is wired up correctly. Running it produces:
torch.Size([1, 256, 6, 6])
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
[MAdd]: Dropout is not supported!
[Flops]: Dropout is not supported!
[Memory]: Dropout is not supported!
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 features.0 3 224 224 64 55 55 23296.0 0.74 140,553,600.0 70,470,400.0 695296.0 774400.0 55.56% 1469696.0
1 features.1 64 55 55 64 55 55 0.0 0.74 193,600.0 193,600.0 774400.0 774400.0 0.00% 1548800.0
2 features.2 64 55 55 64 27 27 0.0 0.18 373,248.0 193,600.0 774400.0 186624.0 5.57% 961024.0
3 features.3 64 27 27 192 27 27 307392.0 0.53 447,897,600.0 224,088,768.0 1416192.0 559872.0 22.21% 1976064.0
4 features.4 192 27 27 192 27 27 0.0 0.53 139,968.0 139,968.0 559872.0 559872.0 0.00% 1119744.0
5 features.5 192 27 27 192 13 13 0.0 0.12 259,584.0 139,968.0 559872.0 129792.0 0.00% 689664.0
6 features.6 192 13 13 384 13 13 663936.0 0.25 224,280,576.0 112,205,184.0 2785536.0 259584.0 0.00% 3045120.0
7 features.7 384 13 13 384 13 13 0.0 0.25 64,896.0 64,896.0 259584.0 259584.0 0.00% 519168.0
8 features.8 384 13 13 256 13 13 884992.0 0.17 299,040,768.0 149,563,648.0 3799552.0 173056.0 5.56% 3972608.0
9 features.9 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 0.00% 346112.0
10 features.10 256 13 13 256 13 13 590080.0 0.17 199,360,512.0 99,723,520.0 2533376.0 173056.0 0.00% 2706432.0
11 features.11 256 13 13 2
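If torchstat is not installed, a quick sanity check with a dummy input gives a similar confirmation. The snippet below is a minimal sketch, assuming the same 10-class AlexNet defined above; the parameter sum it prints should agree with the totals reported by stat:

import torch

model = AlexNet(10)
x = torch.randn(1, 3, 224, 224)   # dummy batch of one 224x224 RGB image
out = model(x)                    # forward() also prints torch.Size([1, 256, 6, 6])
print(out.shape)                  # expected: torch.Size([1, 10])
print(sum(p.numel() for p in model.parameters()))  # total parameter count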