Problem
While reproducing the Attention U-Net code, I got the following error:
File "/root/autodl-tmp/projects/UnetWithResNet/model/Attention_UNet.py", line 163, in <module>
summary(model, (3, 512, 512))
File "/root/miniconda3/lib/python3.8/site-packages/torchsummary/torchsummary.py", line 72, in summary
model(*x)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/root/autodl-tmp/projects/UnetWithResNet/model/Attention_UNet.py", line 134, in forward
python-BaseException
attention1 = self.attention1(g=up1, x=down4) # 512, 64, 64
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
hook_result = hook(self, input, result)
File "/root/miniconda3/lib/python3.8/site-packages/torchsummary/torchsummary.py", line 19, in hook
summary[m_key]["input_shape"] = list(input[0].size())
IndexError: tuple index out of range
The traceback shows that when execution reaches
attention1 = self.attention1(g=up1, x=down4)  # 512, 64, 64
the line
summary[m_key]["input_shape"] = list(input[0].size())
in torchsummary.py raises an IndexError: tuple index out of range.
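A likely cause, illustrated with a minimal sketch of my own (the Toy module below is hypothetical): in this PyTorch version a forward hook only receives the positional arguments of forward, so when every argument is passed by keyword the hook's input tuple is empty and input[0] raises the IndexError:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, g, x):
        return g + x

def hook(module, inputs, output):
    # `inputs` holds only the positional arguments passed to forward
    print("hook saw", len(inputs), "positional input(s)")

m = Toy()
m.register_forward_hook(hook)

t = torch.randn(1, 3)
m(t, t)      # hook saw 2 positional input(s)
m(g=t, x=t)  # hook saw 0 positional input(s) -> inputs[0] would raise IndexError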
Solutions
Method 1
The blogger 一元童鞋 suggested a fix: comment out the following lines in torchsummary.py:
# summary[m_key]["input_shape"] = list(input[0].size())
# summary[m_key]["input_shape"][0] = batch_size
and add in their place:
if len(input) != 0:
    summary[m_key]["input_shape"] = list(input[0].size())
    summary[m_key]["input_shape"][0] = batch_size
else:
    # forward was called with keyword arguments only, so the hook's input tuple is empty
    summary[m_key]["input_shape"] = input
Since I run my code on a remote server, editing torchsummary.py directly would be inconvenient, so I have not tried this method and cannot confirm whether it solves the problem.
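If you do want to apply the patch, one way to locate the installed file (a small sketch, assuming torchsummary is importable in your environment):

import torchsummary.torchsummary as ts
print(ts.__file__)  # prints the path of the torchsummary.py that would need to be edited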
Method 2
Looking at my own code:
attention1 = self.attention1(g=up1, x=down4) # 512, 64, 64
self.attention1 is defined as:
self.attention1 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
and the Attention_block class is implemented as follows:
import torch.nn as nn

class Attention_block(nn.Module):
    """Attention gate."""

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(  # psi maps the F_int-channel feature map to a 1-channel attention map
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)          # 512, 32, 32 -> 256, 32, 32
        x1 = self.W_x(x)          # 512, 32, 32 -> 256, 32, 32
        psi = self.relu(g1 + x1)  # 256, 32, 32
        psi = self.psi(psi)       # 256, 32, 32 -> 1, 32, 32
        return x * psi            # 512, 32, 32
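As a quick standalone sanity check of the block (my own sketch; the 512/256 channel sizes and 64x64 spatial size are inferred from the shape comments above, not taken from the original filters list):

import torch

att = Attention_block(F_g=512, F_l=512, F_int=256)
g = torch.randn(1, 512, 64, 64)
x = torch.randn(1, 512, 64, 64)
out = att(g, x)   # positional call
print(out.shape)  # torch.Size([1, 512, 64, 64]) -- x re-weighted by the attention map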
The problem may lie in how the arguments are passed when calling self.attention1. Changing the call to:
attention1 = self.attention1(up1, down4) # 512, 64, 64
That is, changing:
attention1 = self.attention1(g=up1, x=down4) # 512, 64, 64
to
attention1 = self.attention1(up1, down4) # 512, 64, 64
up1 and down4 do not need to be passed by parameter name; passing them positionally, in order, is enough. Whether this is really the root cause I have not investigated in depth, but it is consistent with the hook behaviour sketched in the Problem section above: positional arguments reach the torchsummary hook, keyword arguments do not.
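For completeness, a minimal end-to-end sketch (my own wrapper, not part of the original model; it assumes a torchsummary version whose summary() accepts a device argument, e.g. 1.5.x) showing that the positional call goes through torchsummary without the IndexError:

import torch.nn as nn
from torchsummary import summary

class AttentionOnly(nn.Module):
    """Hypothetical wrapper used only for this demo."""
    def __init__(self):
        super(AttentionOnly, self).__init__()
        self.attention1 = Attention_block(F_g=512, F_l=512, F_int=256)

    def forward(self, feat):
        # positional call, so torchsummary's hook receives a non-empty input tuple
        return self.attention1(feat, feat)

summary(AttentionOnly(), (512, 64, 64), device="cpu")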
References
https://blog.youkuaiyun.com/onermb/article/details/116149599
https://github.com/LeeJunHyun/Image_Segmentation/issues/69