""" ## Multi-Stage Progressive Image Restoration ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao ## https://arxiv.org/abs/2102.02808 """ import torch import torch.nn as nn import torch.nn.functional as F from pdb import set_trace as stx ########################################################################## def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias, stride=stride) ########################################################################## ## Channel Attention Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16, bias=False): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), nn.Sigmoid() ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ########################################################################## ## Channel Attention Block (CAB) class CAB(nn.Module): def __init__(self, n_feat, kernel_size, reduction, bias, act): super(CAB, self).__init__() modules_body = [] modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) modules_body.append(act) modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) # self.CA = CALayer(n_feat, reduction, bias=bias) self.CA = nn.MultiheadAttention(n_feat, num_heads=8) self.body = nn.Sequential(*modules_body) def forward(self, x): batch_size, channels, height, width = x.size() x_reshape = x.view(batch_size, channels, -1).permute(0, 2, 1) res, _ = self.CA(x_reshape, x_reshape, x_reshape) res = res.permute(0, 2, 1).view(batch_size, channels, height, width) res = self.body(x) + res # res = self.body(x) # res = self.CA(res) # res += x return res ########################################################################## ## Supervised Attention Module class SAM(nn.Module): def __init__(self, n_feat, kernel_size, bias): super(SAM, self).__init__() self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) self.conv5 = conv(n_feat, n_feat, 5, bias=bias) self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) self.conv3 = conv(3, n_feat, kernel_size, bias=bias) self.bn = nn.BatchNorm2d(n_feat) self.relu = nn.ReLU() def forward(self, x, x_img): x1 = self.conv1(x) x1 = self.conv5(x1) x1 = self.bn(x1) x1 = self.relu(x1) img = self.conv2(x) + x_img x2 = torch.sigmoid(self.conv3(img)) x1 = x1 * x2 x1 = x1 + x return x1, img ########################################################################## ## U-Net class Encoder(nn.Module): def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): super(Encoder, self).__init__() self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] self.encoder_level1 = nn.Sequential(*self.encoder_level1) self.encoder_level2 = nn.Sequential(*self.encoder_level2) self.encoder_level3 = nn.Sequential(*self.encoder_level3) self.down12 = DownSample(n_feat, scale_unetfeats) self.down23 
##########################################################################
## U-Net
class Encoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super(Encoder, self).__init__()

        self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.encoder_level1 = nn.Sequential(*self.encoder_level1)
        self.encoder_level2 = nn.Sequential(*self.encoder_level2)
        self.encoder_level3 = nn.Sequential(*self.encoder_level3)

        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)

        # Cross Stage Feature Fusion (CSFF)
        if csff:
            self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

            self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        x = self.down12(enc1)

        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        x = self.down23(enc2)

        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]


class Decoder(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super(Decoder, self).__init__()

        self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)]
        self.decoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)]

        self.decoder_level1 = nn.Sequential(*self.decoder_level1)
        self.decoder_level2 = nn.Sequential(*self.decoder_level2)
        self.decoder_level3 = nn.Sequential(*self.decoder_level3)

        self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)

        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        enc1, enc2, enc3 = outs

        dec3 = self.decoder_level3(enc3)

        x = self.up32(dec3, self.skip_attn2(enc2))
        dec2 = self.decoder_level2(x)

        x = self.up21(dec2, self.skip_attn1(enc1))
        dec1 = self.decoder_level1(x)

        return [dec1, dec2, dec3]
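
##########################################################################
## A minimal, hedged walkthrough of the encoder/decoder channel layout
## (illustrative only; `_demo_unet_shapes` is not part of the original
## code). With n_feat=80 and scale_unetfeats=48, the three levels carry
## 80, 128, and 176 channels at full, 1/2, and 1/4 resolution.
def _demo_unet_shapes():
    n_feat, s = 80, 48
    enc = Encoder(n_feat, kernel_size=3, reduction=4, act=nn.PReLU(),
                  bias=False, scale_unetfeats=s, csff=False)
    dec = Decoder(n_feat, kernel_size=3, reduction=4, act=nn.PReLU(),
                  bias=False, scale_unetfeats=s)
    x = torch.randn(1, n_feat, 16, 16)  # small spatial size: attention is O((H*W)^2)
    with torch.no_grad():
        feats = enc(x)
        outs = dec(feats)
    for i, f in enumerate(feats, start=1):
        print(f"enc level {i}:", tuple(f.shape))  # (1,80,16,16), (1,128,8,8), (1,176,4,4)
    print("dec level 1:", tuple(outs[0].shape))   # (1, 80, 16, 16)
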
##########################################################################
##---------- Resizing Modules ----------
# Original bilinear resizing modules (replaced below):
# class DownSample(nn.Module):
#     def __init__(self, in_channels, s_factor):
#         super(DownSample, self).__init__()
#         self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
#                                   nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False))
#
#     def forward(self, x):
#         x = self.down(x)
#         return x
#
#
# class UpSample(nn.Module):
#     def __init__(self, in_channels, s_factor):
#         super(UpSample, self).__init__()
#         self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
#                                 nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))
#
#     def forward(self, x):
#         x = self.up(x)
#         return x


class DownSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            # downsample with a stride-2 convolution
            nn.Conv2d(in_channels, in_channels + s_factor, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        x = self.down(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(UpSample, self).__init__()
        self.up = nn.Sequential(
            # upsample with a transposed convolution
            nn.ConvTranspose2d(in_channels + s_factor, in_channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        x = self.up(x)
        return x


class SkipUpSample(nn.Module):
    def __init__(self, in_channels, s_factor):
        super(SkipUpSample, self).__init__()
        self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                                nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False))

    def forward(self, x, y):
        x = self.up(x)
        x = x + y
        return x


##########################################################################
## Original Resolution Block (ORB)
class ORB(nn.Module):
    def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
        super(ORB, self).__init__()
        modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


##########################################################################
class ORSNet(nn.Module):
    def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab):
        super(ORSNet, self).__init__()

        self.orb1 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb2 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)
        self.orb3 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab)

        self.up_enc1 = UpSample(n_feat, scale_unetfeats)
        self.up_dec1 = UpSample(n_feat, scale_unetfeats)

        self.up_enc2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats),
                                     UpSample(n_feat, scale_unetfeats))
        self.up_dec2 = nn.Sequential(UpSample(n_feat + scale_unetfeats, scale_unetfeats),
                                     UpSample(n_feat, scale_unetfeats))

        self.conv_enc1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_enc3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)

        self.conv_dec1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)
        self.conv_dec3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs, decoder_outs):
        x = self.orb1(x)
        x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0])

        x = self.orb2(x)
        x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1]))

        x = self.orb3(x)
        x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2]))

        return x
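
##########################################################################
## A minimal, hedged check of the replaced resizing modules (illustrative
## only; `_demo_resize_shapes` is not part of the original code). The
## stride-2 conv halves H and W, and ConvTranspose2d restores them exactly:
##   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
##       = (in - 1) * 2 - 2 + 3 + 1 = 2 * in
def _demo_resize_shapes():
    down = DownSample(in_channels=80, s_factor=48)
    up = UpSample(in_channels=80, s_factor=48)
    x = torch.randn(1, 80, 32, 32)
    with torch.no_grad():
        d = down(x)   # (1, 128, 16, 16)
        u = up(d)     # (1, 80, 32, 32)
    print("down:", tuple(d.shape), "up:", tuple(u.shape))
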
##########################################################################
class MPRNet(nn.Module):
    def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8,
                 kernel_size=3, reduction=4, bias=False):
        super(MPRNet, self).__init__()

        act = nn.PReLU()
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Cross Stage Feature Fusion (CSFF)
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab)

        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)
        self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)

        self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.concat23 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias)
        self.tail = conv(n_feat + scale_orsnetfeats, out_c, kernel_size, bias=bias)

    def forward(self, x3_img):
        # Original-resolution image for Stage 3
        H = x3_img.size(2)
        W = x3_img.size(3)

        # Multi-Patch Hierarchy: split the image into four non-overlapping patches

        # Two patches for Stage 2
        x2top_img = x3_img[:, :, 0:int(H / 2), :]
        x2bot_img = x3_img[:, :, int(H / 2):H, :]

        # Four patches for Stage 1
        x1ltop_img = x2top_img[:, :, :, 0:int(W / 2)]
        x1rtop_img = x2top_img[:, :, :, int(W / 2):W]
        x1lbot_img = x2bot_img[:, :, :, 0:int(W / 2)]
        x1rbot_img = x2bot_img[:, :, :, int(W / 2):W]

        ##-------------------------------------------
        ##-------------- Stage 1---------------------
        ##-------------------------------------------
        ## Compute shallow features
        x1ltop = self.shallow_feat1(x1ltop_img)
        x1rtop = self.shallow_feat1(x1rtop_img)
        x1lbot = self.shallow_feat1(x1lbot_img)
        x1rbot = self.shallow_feat1(x1rbot_img)

        ## Process features of all 4 patches with Encoder of Stage 1
        feat1_ltop = self.stage1_encoder(x1ltop)
        feat1_rtop = self.stage1_encoder(x1rtop)
        feat1_lbot = self.stage1_encoder(x1lbot)
        feat1_rbot = self.stage1_encoder(x1rbot)

        ## Concat deep features
        feat1_top = [torch.cat((k, v), 3) for k, v in zip(feat1_ltop, feat1_rtop)]
        feat1_bot = [torch.cat((k, v), 3) for k, v in zip(feat1_lbot, feat1_rbot)]

        ## Pass features through Decoder of Stage 1
        res1_top = self.stage1_decoder(feat1_top)
        res1_bot = self.stage1_decoder(feat1_bot)

        ## Apply Supervised Attention Module (SAM)
        x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img)
        x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img)

        ## Output image at Stage 1
        stage1_img = torch.cat([stage1_img_top, stage1_img_bot], 2)

        ##-------------------------------------------
        ##-------------- Stage 2---------------------
        ##-------------------------------------------
        ## Compute shallow features
        x2top = self.shallow_feat2(x2top_img)
        x2bot = self.shallow_feat2(x2bot_img)

        ## Concatenate SAM features of Stage 1 with shallow features of Stage 2
        x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1))
        x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1))

        ## Process features of both patches with Encoder of Stage 2
        feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top)
        feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot)

        ## Concat deep features
        feat2 = [torch.cat((k, v), 2) for k, v in zip(feat2_top, feat2_bot)]

        ## Pass features through Decoder of Stage 2
        res2 = self.stage2_decoder(feat2)

        ## Apply SAM
        x3_samfeats, stage2_img = self.sam23(res2[0], x3_img)
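
        # feat2/res2 (stage-2 encoder and decoder features) are reused below:
        # ORSNet fuses them into the full-resolution stage-3 branch via its
        # 1x1 convs, upsampling the coarser levels back to original resolution.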
        ##-------------------------------------------
        ##-------------- Stage 3---------------------
        ##-------------------------------------------
        ## Compute shallow features
        x3 = self.shallow_feat3(x3_img)

        ## Concatenate SAM features of Stage 2 with shallow features of Stage 3
        x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1))

        x3_cat = self.stage3_orsnet(x3_cat, feat2, res2)

        stage3_img = self.tail(x3_cat)

        return [stage3_img + x3_img, stage2_img, stage1_img]


def main():
    # Instantiate an MPRNet model
    model = MPRNet()

    # Evaluation mode: fixes batch-norm statistics and disables any dropout
    model.eval()

    # Random test input. Kept small because the global attention in CAB
    # scales quadratically with H * W: a 256x256 input would need well over
    # 100 GB for a single full-resolution attention map.
    input_image = torch.randn(1, 3, 64, 64)

    # Disable gradient tracking to speed up inference
    with torch.no_grad():
        # Run the model and collect the per-stage outputs
        outputs = model(input_image)

    # Print the output shape of each stage (all should match the input size)
    print("Stage 3 output shape:", outputs[0].shape)
    print("Stage 2 output shape:", outputs[1].shape)
    print("Stage 1 output shape:", outputs[2].shape)


if __name__ == "__main__":
    main()
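
##########################################################################
## Optional: the small illustrative helpers defined above can be run
## instead of (or before) main(); they are hypothetical additions to this
## walkthrough, not part of the original MPRNet code.
# _demo_cab_shape()
# _demo_unet_shapes()
# _demo_resize_shapes()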