pytorch模型

本文介绍了PyTorch中构建模型的方法,包括使用nn.Sequential、nn.ModuleList和nn.ModuleDict。详细阐述了它们的区别,并通过U-Net模型展示了复杂模型的搭建步骤。此外,还讨论了如何修改已有模型,如替换层、增加输入变量和输出变量。最后,详细说明了模型的保存和加载,包括单卡和多卡场景下的处理。

模型基本定义方法

pytorch中有提供nn.Sequential()、nn.ModuleList()以及nn.ModuleDict()用于集成多个Module,完成模型搭建。其异同如下:

Sequential() ModuleList() /ModuleDict()
直接搭建网络,定义顺序即为模型连接顺序 List/Dict中元素顺序并不代表其在网络中的真实位置顺序,需要forward函数指定各个层的连接顺序
模型中间无法加入外部输入 模型中间需要之前层的信息的时候,比如 ResNets 中的残差计算,比较方便

通过nn.Sequential()

# 方法一:
import torch.nn as nn
net = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
# 方法二:
import collections
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)),
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))

通过nn.ModuleList()/nn.ModuleDict()

# List
class model(nn.Module):
  def __init__(self):
    super().__init__()
    self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(),nn.Linear(256, 10)])
    
  def forward(self, x):
    for layer in self.modulelist:
      x = layer(x)
    return x
# Dict
class model(nn.Module):
  def __init__(self):
    super().__init__()
    self.moduledict = nn.ModuleDict({
   
   
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
    'output':nn.Linear(256, 10)
    })
    
  def forward(self, x):
    for layer in self.moduledict:
      x = layer(x)
    return x

复杂模型搭建方法

对于大型复杂模型,可以先将模型分块,然后在进行模型搭建。以U-Net模型为例。
在这里插入图片描述
上图为U-Net网络结构,可以分为以下四个模块:

  • 每个子块内部的两次卷积(Double Convolution)
  • 左侧模型块之间的下采样连接,即最大池化(Max pooling)
  • 右侧模型块之间的上采样连接(Up sampling)
  • 输出层的处理

模块构建

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值