#---Part1--- SSAnetwork.py
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 27 13:43:52 2022
@author: User
"""
import torch.nn as nn
def weights_init(mod):
    """
    Custom weights initialization called on netG, netD and netE.

    Intended to be used as ``model.apply(weights_init)``:

    - class name containing ``'Conv'``: weight ~ N(0.0, 0.02)
    - class name containing ``'BatchNorm'``: weight ~ N(1.0, 0.02), bias = 0
    - any other module is left untouched.

    Modules that match by class name but carry no ``weight``/``bias`` (or carry
    ``None``, as ``nn.BatchNorm2d(..., affine=False)`` does) are skipped instead
    of raising ``AttributeError``.

    :param mod: the ``nn.Module`` visited by ``Module.apply``.
    :return: None; the module's parameters are modified in place.
    """
    classname = mod.__class__.__name__
    # getattr with a default covers both "attribute missing" (e.g. a container
    # whose class name merely contains 'Conv') and "registered as None".
    weight = getattr(mod, 'weight', None)
    bias = getattr(mod, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
# ==================================================Alister add 2022-10-14======================
import torch.nn as nn
import torch.nn.functional as F
import torch
from network.attension import ChannelAttention, SpatialAttention
from network.SelfAttension import Self_Attn
import torchvision.models as models
def weights_init_normal(m):
    """
    Normal-distribution initialisation, meant for ``model.apply(weights_init_normal)``.

    - class name containing ``"Conv"``: weight ~ N(0.0, 0.02)
    - class name containing ``"BatchNorm2d"``: weight ~ N(1.0, 0.02), bias = 0
    - anything else (including BatchNorm1d/3d) is left untouched.

    Modules that match by name but have no weight/bias, or have ``None``
    (e.g. ``nn.BatchNorm2d(..., affine=False)``), are skipped rather than
    raising ``AttributeError``.

    :param m: module visited by ``nn.Module.apply``.
    :return: None; parameters are modified in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if classname.find("Conv") != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)
##############################
# U-NET
##############################
# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
# NOTE: Disabled earlier variant of UNetDown, kept for reference only. It sits
# inside a bare string literal, so it is never executed; the active UNetDown
# class follows right after it. In this variant the stride-2 conv output was
# refined by channel attention and then spatial attention and added back
# residually (x = xc + xa). It also built ``nn.Dropout()`` with the default
# rate instead of the requested ``dropout`` value.
'''
class UNetDown(nn.Module):
def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
super(UNetDown, self).__init__()
self.leakyrelu = nn.LeakyReLU(0.2)
self.conv = nn.Conv2d(in_size, out_size, 4, 2, 1)
self.ca = ChannelAttention(out_size)
self.sa = SpatialAttention()
self.norm = nn.InstanceNorm2d(out_size)
self.normalize = normalize
self.dropout = dropout
self.drop = nn.Dropout()
#layers.append(nn.Conv2d(in_size, out_size, 4, 2, 1))
#if normalize:
#layers.append(nn.InstanceNorm2d(out_size))
#if dropout:
#layers.append(nn.Dropout(dropout))
#self.model = nn.Sequential(*layers)
def forward(self, x):
x = self.leakyrelu(x)
xc = self.conv(x)
xa = self.ca(xc)*xc
xa = self.sa(xa)*xa
x = xc + xa
if self.normalize:
x = self.norm(x)
if self.dropout:
x = self.drop(x)
return x
'''
class UNetDown(nn.Module):
    """
    U-Net encoder stage.

    LeakyReLU(0.2) -> Conv2d(k=4, s=2, p=1), which halves even H and W,
    optionally followed by InstanceNorm2d and Dropout. All layers live in the
    ``self.model`` Sequential.

    :param in_size: input channels.
    :param out_size: output channels.
    :param normalize: append ``nn.InstanceNorm2d(out_size)`` when truthy.
    :param dropout: dropout probability; a falsy value (e.g. 0.0) adds no
        dropout layer.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.LeakyReLU(0.2), nn.Conv2d(in_size, out_size, 4, 2, 1)]
        stages += [nn.InstanceNorm2d(out_size)] if normalize else []
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage's layers in order."""
        return self.model(x)
class UNetUp(nn.Module):
    """
    U-Net decoder stage.

    2x transposed-conv upsampling (+BatchNorm+ReLU) of ``x``, channel-wise
    concatenation with a skip tensor, then a 3x3 conv (+BatchNorm+ReLU) that
    fuses the two back to ``out_size`` channels.

    :param in_size: channels of the incoming decoder tensor ``x``.
    :param out_size: channels after upsampling. The skip tensor given to
        ``forward`` must also have ``out_size`` channels, because the fusion
        conv expects ``2 * out_size`` input channels.
    """

    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        # ConvTranspose2d(k=4, s=2, p=1) maps H -> 2H and W -> 2W.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(True)
        )
        # The fusion conv's input channel count accounts for the concatenation:
        # out_size (upsampled x) + out_size (skip_input).
        self.conv = nn.Sequential(
            nn.Conv2d(out_size + out_size, out_size, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(True)
        )

    def forward(self, x, skip_input):
        """
        :param x: decoder tensor with ``in_size`` channels, presumably (N, C, H, W).
        :param skip_input: skip features with ``out_size`` channels. They are
            resized bilinearly when their spatial size differs from the
            upsampled ``x``.
        :return: tensor with ``out_size`` channels at twice the spatial size of ``x``.
        """
        x = self.up(x)
        # Only the spatial dims have to agree for the channel-wise concat.
        # Comparing the full size() would also trigger a same-size resize when
        # just the channel counts differ.
        if skip_input.shape[2:] != x.shape[2:]:
            skip_input = F.interpolate(skip_input, size=x.shape[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, skip_input), 1)
        return self.conv(x)
class ResNet50FeatureExtractor(nn.Module):
    """
    Multi-scale feature extractor built from torchvision's ResNet-50.

    ``forward`` returns five feature maps, with resolution relative to the input:
    [stem (64 ch, 1/4), layer1 (256 ch, 1/4), layer2 (512 ch, 1/8),
    layer3 (1024 ch, 1/16), layer4 (2048 ch, 1/32)].
    The ResNet classification head (avgpool/fc) is not kept. The stem's first
    conv is torchvision's 3-input-channel conv, so inputs must be 3-channel.

    :param pretrained: load the ImageNet ``IMAGENET1K_V1`` weights. torchvision
        may need to download them on first use.
    """

    def __init__(self, pretrained=True):
        super(ResNet50FeatureExtractor, self).__init__()
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the boolean flag.
            resnet = models.resnet50(pretrained=pretrained)
        else:
            # torchvision >= 0.13 deprecates ``pretrained``. IMAGENET1K_V1 is the
            # checkpoint that ``pretrained=True`` selects, so weights are unchanged.
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        # Stem: conv1 -> bn1 -> relu -> maxpool (overall stride 4).
        self.conv1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.layer1 = resnet.layer1  # 256 channels
        self.layer2 = resnet.layer2  # 512 channels
        self.layer3 = resnet.layer3  # 1024 channels
        self.layer4 = resnet.layer4  # 2048 channels

    def forward(self, x):
        """Return ``[stem, layer1, layer2, layer3, layer4]`` features of ``x``."""
        x1 = self.conv1(x)   # 1/4
        x2 = self.layer1(x1)  # 1/4
        x3 = self.layer2(x2)  # 1/8
        x4 = self.layer3(x3)  # 1/16
        x5 = self.layer4(x4)  # 1/32
        return [x1, x2, x3, x4, x5]
class SSAEncoder(nn.Module):
    """
    UNET ENCODER NETWORK

    Uses a ResNet-50 backbone. Its deepest (layer4) features are projected to
    512 channels with a 1x1 conv, refined by channel attention and then spatial
    attention (each applied multiplicatively), and optionally squeezed to ``nz``
    channels by ``final_conv``.

    :param isize: stored as ``self.isize``; it does not change the architecture.
    :param nz: output channels of ``final_conv``.
    :param nc, ndf, ngpu, n_extra_layers: accepted for interface compatibility
        and not used. The ResNet-50 stem fixes the input to 3 channels.
    :param add_final_conv: when True, add a 4x4 valid conv (512 -> nz).
    """

    def __init__(self, isize=32, nz=100, nc=3, ndf=64, ngpu=1, n_extra_layers=0, add_final_conv=True):
        super(SSAEncoder, self).__init__()
        self.isize = isize
        # ResNet-50 used as the multi-scale feature extractor.
        self.resnet = ResNet50FeatureExtractor(pretrained=True)
        # 1x1 feature projections.
        self.conv1 = nn.Conv2d(64, 64, 1)     # stem output
        self.conv2 = nn.Conv2d(256, 128, 1)   # layer1 output
        self.conv3 = nn.Conv2d(512, 256, 1)   # layer2 output
        self.conv4 = nn.Conv2d(1024, 512, 1)  # layer3 output
        self.conv5 = nn.Conv2d(2048, 512, 1)  # layer4 output
        if add_final_conv:
            # 4x4 valid conv: 512 -> nz channels; shrinks H and W by 3.
            self.final_conv = nn.Conv2d(512, nz, 4, 1, 0, bias=False)
        # Channel / spatial attention modules, one pair per scale.
        self.ca1 = ChannelAttention(64)
        self.sa1 = SpatialAttention()
        self.ca2 = ChannelAttention(128)
        self.sa2 = SpatialAttention()
        self.ca3 = ChannelAttention(256)
        self.sa3 = SpatialAttention()
        self.ca4 = ChannelAttention(512)
        self.sa4 = SpatialAttention()
        self.ca5 = ChannelAttention(512)
        self.sa5 = SpatialAttention()

    def forward(self, x):
        """
        Encode ``x``.

        :param x: image batch; must be 3-channel for the ResNet-50 stem.
        :return: ``final_conv`` output (``nz`` channels) when the encoder was
            built with ``add_final_conv=True``, otherwise the attended
            512-channel layer4 features.
        """
        # Multi-scale ResNet-50 features; only the deepest one is used below.
        features = self.resnet(x)
        # Only the layer4 branch contributes to the return value, so only that
        # branch is evaluated. The shallower projections (conv1-4 with
        # ca1-4/sa1-4) stay registered so existing state_dict checkpoints still load.
        # NOTE(review): skipping them assumes ChannelAttention/SpatialAttention
        # draw no random numbers in forward (e.g. no dropout). Confirm in
        # network.attension.
        d5 = self.conv5(features[4])
        d5 = self.ca5(d5) * d5
        d5 = self.sa5(d5) * d5
        if hasattr(self, 'final_conv'):
            return self.final_conv(d5)
        return d5
class SSADecoder(nn.Module):
    """
    UNET DECODER NETWORK

    Expands a latent tensor back to an image. There are four UNetUp stages;
    their skip connections come from projected ResNet-50 features (layer4 down
    to layer1), and each stage is followed by a self-attention block. The
    output is an ``nc``-channel Tanh image plus a single-channel Sigmoid map,
    both at twice the spatial size of the last stage.

    :param isize: accepted for interface symmetry; not used here.
    :param nc: channels of the generated image.
    :param nz: channels of the incoming latent tensor.
    """

    def __init__(self, isize, nc, nz):
        super(SSADecoder, self).__init__()
        # Stem: ConvTranspose2d(k=4, s=1, p=0) grows H and W by 3 (1x1 -> 4x4).
        self.init_conv = nn.Sequential(
            nn.ConvTranspose2d(nz, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )
        # Upsampling stages; each doubles H and W. Module construction order is
        # deliberately unchanged so parameter initialisation draws from the RNG
        # in the same sequence.
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(512, 256)
        self.up3 = UNetUp(256, 128)
        self.up4 = UNetUp(128, 64)
        # 1x1 projections bringing ResNet-50 features to each stage's width.
        self.feat_conv1 = nn.Conv2d(2048, 512, 1)  # layer4 -> 512
        self.feat_conv2 = nn.Conv2d(1024, 256, 1)  # layer3 -> 256
        self.feat_conv3 = nn.Conv2d(512, 128, 1)   # layer2 -> 128
        self.feat_conv4 = nn.Conv2d(256, 64, 1)    # layer1 -> 64
        # Self-attention after each stage; channels match the stage output.
        self.attn1 = Self_Attn(512, 1, 'relu')
        self.attn2 = Self_Attn(256, 1, 'relu')
        self.attn3 = Self_Attn(128, 1, 'relu')
        self.attn4 = Self_Attn(64, 1, 'relu')
        # Image head: 2x (nearest) upsample, 3x3 conv to nc channels, Tanh.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, nc, 3, 1, 1),
            nn.Tanh(),
        )
        # Attention-map head: 2x upsample, 1x1 conv to one channel, Sigmoid.
        self.attn_out = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, features):
        """
        :param x: latent tensor with ``nz`` channels.
        :param features: list of ResNet-50 feature maps indexed as
            [stem, layer1, layer2, layer3, layer4]; indices 1-4 are used.
        :return: ``(gen_imag, gen_attn)``, the generated image and the
            single-channel attention map.
        """
        out = self.init_conv(x)
        # Each entry: (upsampling stage, skip projection, attention, backbone map).
        stages = (
            (self.up1, self.feat_conv1, self.attn1, features[4]),
            (self.up2, self.feat_conv2, self.attn2, features[3]),
            (self.up3, self.feat_conv3, self.attn3, features[2]),
            (self.up4, self.feat_conv4, self.attn4, features[1]),
        )
        for up, project, attend, backbone_map in stages:
            out = up(out, project(backbone_map))
            # Self_Attn yields a pair; only the first element is carried on.
            out, _ = attend(out)
        return self.final(out), self.attn_out(out)
class SSANetG(nn.Module):
    """
    GENERATOR NETWORK

    Encoder -> decoder -> encoder.

    - ``encoder1`` maps the input to ``latent_i``.
    - ``decoder`` reconstructs an image plus a single-channel attention map,
      using the input's ResNet-50 features as skip connections.
    - ``encoder2`` re-encodes the reconstruction to ``latent_o``.

    :param nc: channels of the generated image (decoder output).
    :param nz: latent channel count shared by the encoders and the decoder.
    :param isize, ngf, ngpu, extralayers: passed to the sub-networks, which do
        not use them to change the architecture. They are stored as attributes.
    :param ndf: stored as an attribute only.
    """

    def __init__(self, isize=32, nc=3, nz=100, ngf=64, ndf=64, ngpu=1, extralayers=0):
        super(SSANetG, self).__init__()
        self.isize = isize
        self.nc = nc
        self.nz = nz
        self.ngf = ngf
        self.ndf = ndf
        self.ngpu = ngpu
        self.extralayers = extralayers
        self.encoder1 = SSAEncoder(self.isize, self.nz, self.nc, self.ngf, self.ngpu, self.extralayers)
        self.decoder = SSADecoder(self.isize, self.nc, self.nz)
        self.encoder2 = SSAEncoder(self.isize, self.nz, self.nc, self.ngf, self.ngpu, self.extralayers)

    def forward(self, x):
        """
        :param x: input image batch (3-channel, as the ResNet-50 stem requires).
        :return: ``(gen_imag, latent_i, latent_o, gen_attn)``.
        """
        # First encoding. The backbone features of x feed the decoder's skips.
        # NOTE(review): ``self.encoder1(x)`` runs encoder1.resnet on x again, so
        # the backbone is evaluated twice here. In train mode its BatchNorm
        # running stats are also updated twice. Consider letting SSAEncoder
        # return its features.
        features = self.encoder1.resnet(x)
        latent_i = self.encoder1(x)
        # Decode.
        gen_imag, gen_attn = self.decoder(latent_i, features)
        # Second encoding of the reconstruction. No backbone features of
        # gen_imag are needed, so only the encoder itself is run; a separate
        # encoder2.resnet pass would be a wasted ResNet-50 forward.
        latent_o = self.encoder2(gen_imag)
        return gen_imag, latent_i, latent_o, gen_attn
class SSANetD(nn.Module):
    """
    DISCRIMINATOR NETWORK

    Steps:

    1. Concatenate an image with a (1- or 3-channel) attention map and fuse
       them back to 3 channels with a 1x1 conv.
    2. Run the result through a ResNet-50 backbone.
    3. Project layer4 to 512 channels and apply channel attention, then
       spatial attention.
    4. Score with AdaptiveAvgPool2d(1) -> 1x1 conv -> Sigmoid.

    :param isize: accepted for interface compatibility; not used.
    :param nc: channels of the image ``x`` passed to ``forward``.
    :param attn_c: accepted for interface compatibility; not used. The
        attention map's channel count (1 or 3) is handled at runtime.
    """

    def __init__(self, isize=32, nc=3, attn_c=3):
        super(SSANetD, self).__init__()
        self.resnet = ResNet50FeatureExtractor(pretrained=True)
        self.fusion_conv = nn.Conv2d(nc + 3, 3, 1)  # always concatenates a 3-channel attention map
        self.attn_adjust = nn.Conv2d(1, 3, 1)  # lifts a 1-channel attention map to 3 channels
        self.conv1 = nn.Conv2d(64, 64, 1)
        self.conv2 = nn.Conv2d(256, 128, 1)
        self.conv3 = nn.Conv2d(512, 256, 1)
        self.conv4 = nn.Conv2d(1024, 512, 1)
        self.conv5 = nn.Conv2d(2048, 512, 1)
        self.ca1 = ChannelAttention(64)
        self.sa1 = SpatialAttention()
        self.ca2 = ChannelAttention(128)
        self.sa2 = SpatialAttention()
        self.ca3 = ChannelAttention(256)
        self.sa3 = SpatialAttention()
        self.ca4 = ChannelAttention(512)
        self.sa4 = SpatialAttention()
        self.ca5 = ChannelAttention(512)
        self.sa5 = SpatialAttention()
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, attention):
        """
        :param x: image batch with ``nc`` channels.
        :param attention: attention map with 1 or 3 channels and the same
            spatial size as ``x``. A 1-channel map is lifted to 3 channels by
            ``attn_adjust``.
        :return: ``(classifier, features)``. ``classifier`` holds per-sample
            Sigmoid scores of shape (N,); ``features`` is the attended
            512-channel layer4 map.
        :raises ValueError: if ``attention`` has neither 1 nor 3 channels, or if
            ``x`` does not have ``nc`` channels.
        """
        # Adapt the attention map's channel count automatically.
        if attention.shape[1] == 1:
            attention = self.attn_adjust(attention)
        elif attention.shape[1] != 3:
            raise ValueError(f"attention通道数为{attention.shape[1]},仅支持1或3通道")
        fusion = torch.cat((x, attention), 1)
        if fusion.shape[1] != self.fusion_conv.in_channels:
            raise ValueError(f"fusion_conv输入通道数为{self.fusion_conv.in_channels},但实际输入为{fusion.shape[1]},请检查attention通道数!")
        fusion = self.fusion_conv(fusion)
        features = self.resnet(fusion)
        # Only the layer4 branch reaches the classifier, so only it is evaluated.
        # conv1-4 / ca1-4 / sa1-4 stay registered so existing checkpoints load.
        # NOTE(review): skipping them assumes the attention modules draw no
        # random numbers in forward. Confirm in network.attension.
        d5 = self.conv5(features[4])
        d5 = self.ca5(d5) * d5
        d5 = self.sa5(d5) * d5
        features = d5
        classifier = self.classifier(features)
        # (N, 1, 1, 1) -> (N,)
        classifier = classifier.view(-1, 1).squeeze(1)
        return classifier, features
# 最新发布 ("latest release"): apparently stray web-page text that ended up in this file; it is not code.