目录
模型基本定义方法
pytorch中有提供nn.Sequential()、nn.ModuleList()以及nn.ModuleDict()用于集成多个Module,完成模型搭建。其异同如下:
| Sequential() | ModuleList() /ModuleDict() |
|---|---|
| 直接搭建网络,定义顺序即为模型连接顺序 | List/Dict中元素顺序并不代表其在网络中的真实位置顺序,需要forward函数指定各个层的连接顺序 |
| 模型中间无法加入外部输入 | 模型中间需要之前层的信息的时候,比如 ResNets 中的残差计算,比较方便 |
通过nn.Sequential()
# 方法一:
import torch.nn as nn
net = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
# 方法二:
import collections
net2 = nn.Sequential(collections.OrderedDict([
('fc1', nn.Linear(784, 256)),
('relu1', nn.ReLU()),
('fc2', nn.Linear(256, 10))
]))
通过nn.ModuleList()/nn.ModuleDict()
# List
class model(nn.Module):
def __init__(self):
super().__init__()
self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(),nn.Linear(256, 10)])
def forward(self, x):
for layer in self.modulelist:
x = layer(x)
return x
# Dict
class model(nn.Module):
def __init__(self):
super().__init__()
self.moduledict = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
'output':nn.Linear(256, 10)
})
def forward(self, x):
for layer in self.moduledict:
x = layer(x)
return x
复杂模型搭建方法
对于大型复杂模型,可以先将模型分块,然后在进行模型搭建。以U-Net模型为例。

上图为U-Net网络结构,可以分为以下四个模块:
- 每个子块内部的两次卷积(Double Convolution)
- 左侧模型块之间的下采样连接,即最大池化(Max pooling)
- 右侧模型块之间的上采样连接(Up sampling)
- 输出层的处理
模块构建
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True)

本文介绍了PyTorch中构建模型的方法,包括使用nn.Sequential、nn.ModuleList和nn.ModuleDict。详细阐述了它们的区别,并通过U-Net模型展示了复杂模型的搭建步骤。此外,还讨论了如何修改已有模型,如替换层、增加输入变量和输出变量。最后,详细说明了模型的保存和加载,包括单卡和多卡场景下的处理。
最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



