pytorch系列教程(八)-对forward函数的重新理解

本文介绍了U-Network模型的构造及应用过程。该模型通过定义类`U_Network`实现,利用PyTorch框架进行搭建,并详细展示了如何将输入数据传递给模型以生成flow场。在训练过程中,输入数据形状由[C,D,W,H]变为[B,C,D,W,H],新增的维度B代表Batch Size。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上代码

class U_Network(nn.Module):
  def __init__(self, dim, enc_nf, dec_nf, bn=None, full_size=True):
        super(U_Network, self).__init__()
        ...
        ...
        ...
  def forward(self, src, tgt):
      x = torch.cat([src, tgt], dim=1)
      ...
      ...
      ...
      return flow

####################从这里开始分析
UNet = U_Network(len(vol_size), nf_enc, nf_dec).to(device)
##原本数据是[C, D, W, H] 
##但经过enumerate(train_dataloader)出来之后,数据就变成了[B, C, D, W, H],新增加了一个维度 B代表BatchSize的意思
for i,(input_moving,labels) in enumerate(train_dataloader):
      # Generate the moving images and convert them to tensors.
      
      # [B, C, D, W, H]
      input_moving = input_moving.to(device).float()

      # Run the data through the model to produce warp and flow field
      ####UNet(input_moving, input_fixed)就会跳到了U_Network类中的forward函数
      ###所以src=input_moving  tgt=input_fixed
      ###flow_m2f=flow
      flow_m2f = UNet(input_moving, input_fixed)    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值