Paper: https://arxiv.org/abs/1802.06955
This model (R2U-Net) integrates recurrent modules into the U-Net architecture for image segmentation; how well it actually works is a matter of opinion.
The model structure is shown in the figure.
A PyTorch reimplementation follows.
The first piece to implement is the recurrent block:
import torch
import torch.nn as nn

class Recurrent_block(nn.Module):
    """Recurrent convolution block: the same conv layer is applied t times,
    each time taking the sum of its previous output and the block input."""
    def __init__(self, out_ch, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.out_ch = out_ch
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                out = self.conv(x)    # first pass: plain convolution
            out = self.conv(out + x)  # recurrent pass: feed the sum back in
        return out
class RRCNN_block(nn.Module):
    """Recurrent residual block: a 1x1 conv adjusts the channel count,
    two stacked recurrent blocks follow, and a residual connection adds
    the 1x1 conv output back to the result."""
    def __init__(self, in_ch, out_ch, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(out_ch, t=t),
            Recurrent_block(out_ch, t=t)
        )
        self.Conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.Conv(x)   # match channels for the recurrent blocks
        x2 = self.RCNN(x1)
        out = x1 + x2       # residual connection
        return out
1. The for loop is the core of the recurrence: the convolution's output is added to the block's own input and fed through the same convolution again, which is what implements the recurrent connection.
2. t=2 means the loop runs twice; this number can be changed as needed.
3. The paper describes two recurrent module structures, (b) and (d) in the figure below.
The version implemented here is (b). To implement (d), also add the input x to out inside forward, as sketched below; you can adapt this yourself.
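For reference, a minimal sketch of the (d) variant, assuming it differs from (b) only in that final addition (the class name Recurrent_block_d is my own, not from the paper):

class Recurrent_block_d(Recurrent_block):
    """Hypothetical (d) variant: same recurrence as (b), but the block
    input is also added to the final output."""
    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                out = self.conv(x)
            out = self.conv(out + x)
        return out + x  # extra addition that distinguishes (d) from (b)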
The upsampling step of the decoder branch is implemented as follows:
class up_conv(nn.Module):
    """Decoder upsampling: 2x upsample followed by a 3x3 conv,
    batch norm and ReLU."""
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),  # default mode is 'nearest'
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x
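A throwaway sanity check (my own snippet, not part of the model) shows the block doubling the spatial size while the 3x3 conv halves the channels:

up = up_conv(128, 64)
x = torch.randn(1, 128, 16, 16)  # (batch, channels, height, width)
print(up(x).shape)               # torch.Size([1, 64, 32, 32])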
The full model is assembled as follows:
class R2UNet(nn.Module):
    """R2U-Net: a U-Net in which every double-conv block is replaced by
    an RRCNN_block. output_ch is the number of segmentation classes."""
    def __init__(self, img_ch=3, output_ch=2, t=2):
        super(R2UNet, self).__init__()
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        # a single stateless pooling layer, reused at every encoder stage
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # encoder
        self.RRCNN1 = RRCNN_block(img_ch, filters[0], t=t)
        self.RRCNN2 = RRCNN_block(filters[0], filters[1], t=t)
        self.RRCNN3 = RRCNN_block(filters[1], filters[2], t=t)
        self.RRCNN4 = RRCNN_block(filters[2], filters[3], t=t)
        self.RRCNN5 = RRCNN_block(filters[3], filters[4], t=t)

        # decoder: each stage upsamples, concatenates the skip connection
        # (doubling the channels), then runs another RRCNN block
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_RRCNN5 = RRCNN_block(filters[4], filters[3], t=t)
        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_RRCNN4 = RRCNN_block(filters[3], filters[2], t=t)
        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_RRCNN3 = RRCNN_block(filters[2], filters[1], t=t)
        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_RRCNN2 = RRCNN_block(filters[1], filters[0], t=t)

        # 1x1 conv maps to the output classes
        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoder path
        e1 = self.RRCNN1(x)
        e2 = self.Maxpool(e1)
        e2 = self.RRCNN2(e2)
        e3 = self.Maxpool(e2)
        e3 = self.RRCNN3(e3)
        e4 = self.Maxpool(e3)
        e4 = self.RRCNN4(e4)
        e5 = self.Maxpool(e4)
        e5 = self.RRCNN5(e5)

        # decoder path with skip connections
        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)
        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)
        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)
        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        out = self.Conv(d2)
        return out
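As a final sanity check (again my own snippet), the assembled model runs end to end on a dummy batch. The 256x256 size is arbitrary, but the input should be divisible by 16 so the four poolings divide evenly:

model = R2UNet(img_ch=3, output_ch=2, t=2)
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 256, 256])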
Finally, I should point out that the results in this paper are disputed. Some follow-up studies reproduced the model in their experiments and compared against it, while other researchers claim its results are simply not good.
Moreover, the recurrent modules make training difficult. My reimplementation is only guaranteed to be correct in principle; I cannot guarantee what results it will ultimately produce.