SA-GAN: A PyTorch Implementation of Self-Attention (for Images)

This post looks at why conditional convolutional GANs struggle with certain image classes, particularly those with complex global structure and fine textures, and introduces the Self-Attention GAN (SA-GAN) as a remedy: by modeling global dependencies, it combines features from different positions effectively and improves generation quality.


Problem

Conditional convolutional GANs generate good images for weakly constrained classes such as sea or sky, but do poorly on classes with fine textures and strong global structure, such as human faces (the facial features may not line up) or dogs (a dog may end up with the wrong number of legs, or an inconsistent coat color).

Possible cause

Most convolutional networks rely heavily on local receptive fields and therefore fail to capture global structure. In addition, fine texture features are gradually lost after repeated convolutions.

How SA-GAN addresses it

Do not rely on local features alone; also exploit global features by combining the feature responses at different positions. In practice this is done by multiplying a projected feature map with its own transpose, which relates every spatial position to every other one.
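Concretely, for an input feature map x with N = H × W spatial positions, the paper (https://arxiv.org/abs/1805.08318) projects x with three 1×1 convolutions f, g, h and computes

    s_{ij} = f(x_i)^\top g(x_j), \qquad
    \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \qquad
    o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad
    y_j = \gamma\, o_j + x_j

Here β_{j,i} is the weight with which position i contributes when synthesizing position j, and γ is a learnable scalar initialized to 0, so the block starts out as an identity mapping and learns to use non-local evidence gradually. (The published model applies one more 1×1 convolution v to o_j; the implementation below omits it.)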

"""
    Author: Mingle Xu
    Time: 18 Mar, 2019
    Time: 2020-04-08 22:04:07 revise something
"""

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    r"""
        Create global dependence.
        Source paper: https://arxiv.org/abs/1805.08318
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # 1x1 convolutions projecting the input into query (f), key (g)
        # and value (h) spaces; f and g are reduced to C//8 channels to
        # keep the (H*W) x (H*W) attention matrix cheap to compute.
        self.f = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1)
        self.g = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1)
        self.h = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.softmax_ = nn.Softmax(dim=-1)
        # gamma starts at 0 so the block is an identity mapping at the
        # beginning of training.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.init_weight(self.f)
        self.init_weight(self.g)
        self.init_weight(self.h)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        assert channels == self.in_channels

        f = self.f(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # B x (H*W) x C//8
        g = self.g(x).view(batch_size, -1, height * width)  # B x C//8 x (H*W)

        attention = torch.bmm(f, g)  # B x (H*W) x (H*W)
        # Normalize over the last dimension: row j holds the weights
        # beta_{j,i} with which every position i contributes to output
        # position j.
        attention = self.softmax_(attention)

        h = self.h(x).view(batch_size, channels, -1)  # B x C x (H*W)

        # Transpose the attention map so that the weighted sum in bmm runs
        # over the normalized axis (the original code summed over the
        # unnormalized one, so the weights per output position did not
        # sum to 1).
        self_attention_map = torch.bmm(h, attention.permute(0, 2, 1)).view(
            batch_size, channels, height, width)  # B x C x H x W

        # Residual connection scaled by the learnable gamma.
        return self.gamma * self_attention_map + x

    def init_weight(self, conv):
        nn.init.kaiming_uniform_(conv.weight)
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)
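
A minimal smoke test, just a sketch with arbitrary tensor sizes, to check that the block is shape-preserving and can therefore be dropped between the convolutional layers of a generator or discriminator:

if __name__ == '__main__':
    attn = SelfAttention(in_channels=64)
    x = torch.randn(2, 64, 32, 32)  # B x C x H x W
    y = attn(x)
    assert y.shape == x.shape  # the gamma-weighted residual keeps the input shape
    print(y.shape)  # torch.Size([2, 64, 32, 32])

For reference, the paper reports its best FID when self-attention is inserted at the mid-to-high-resolution feature maps (32x32 and 64x64), in both the generator and the discriminator.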