上代码
class U_Network(nn.Module):
    """Network that maps a (src, tgt) image pair to a flow field.

    NOTE(review): the layer construction and most of ``forward`` are elided
    (``...``) in this excerpt, so ``flow`` is not defined in the code as shown.
    """

    def __init__(self, dim, enc_nf, dec_nf, bn=None, full_size=True):
        """Build the network.

        Args:
            dim: the caller passes ``len(vol_size)``; presumably the number of
                spatial dimensions of the input volume - confirm.
            enc_nf: encoder feature configuration; presumably per-layer channel
                counts - confirm against the elided body.
            dec_nf: decoder feature configuration; presumably per-layer channel
                counts - confirm against the elided body.
            bn: not used in the visible code; presumably toggles batch
                normalization - confirm.
            full_size: not used in the visible code; meaning unconfirmed.
        """
        super(U_Network, self).__init__()
        ...  # layer construction elided in this excerpt
        ...
        ...

    def forward(self, src, tgt):
        """Return the flow field predicted for the pair (``src``, ``tgt``).

        Args:
            src: source (moving) image batch; the calling script documents its
                layout as [B, C, D, W, H] - confirm.
            tgt: target (fixed) image batch; must match ``src`` in every
                dimension except dim 1, as required by ``torch.cat``.

        Returns:
            flow: computed in the elided part of this method; presumably the
            displacement field that warps ``src`` toward ``tgt`` - confirm.
        """
        # Stack source and target along dim 1 (the channel axis for a
        # [B, C, ...] layout) so the network receives both images as one input.
        x = torch.cat([src, tgt], dim=1)
        ...  # network body elided in this excerpt
        ...
        ...
        return flow
# ------------------------------ analysis starts here ------------------------------
# Build the network and move it to `device`. `len(vol_size)` is passed as `dim`,
# presumably the number of spatial dimensions of the volume - confirm.
UNet = U_Network(len(vol_size), nf_enc, nf_dec).to(device)
# Each dataset sample is assumed to be [C, D, W, H]. The DataLoader (not
# enumerate) collates samples into batches of shape [B, C, D, W, H], where B is
# the batch size; enumerate only pairs each batch with its index `i`.
for i,(input_moving,labels) in enumerate(train_dataloader):
    # Move the batch of moving images to `device` and cast it to float32.
    # [B, C, D, W, H]
    input_moving = input_moving.to(device).float()
    # Run the data through the model to produce the flow field.
    # Calling UNet(input_moving, input_fixed) goes through nn.Module.__call__,
    # which dispatches to U_Network.forward, so src=input_moving and
    # tgt=input_fixed; the `flow` returned by forward is bound to flow_m2f.
    # NOTE(review): `input_fixed` is not defined in this excerpt (presumably the
    # fixed image prepared earlier and already on `device` - confirm), and
    # `labels` is unused here.
    flow_m2f = UNet(input_moving, input_fixed)