# model_MTS2ONet

import argparse

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable

from torchgan.layers import SpectralNorm2d
import lpips

from ssim import msssim
from vggloss_1 import VGGLoss
from CBAM import CBAM


class SelfAttention(nn.Module):  # diff: add self-attention module
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, height, width = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(batch_size, -1, height * width)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N, row-normalized
        proj_value = self.value_conv(x).view(batch_size, -1, height * width)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(batch_size, C, height, width)

        out = self.gamma * out + x  # learnable residual gate; gamma starts at 0 (identity)
        return out
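
# Hedged usage note (added, not part of the original file): SelfAttention is
# shape-preserving, so it can be dropped in at any encoder stage; because the
# attention matrix is (H*W) x (H*W), it is only practical on small feature
# maps (below it is applied to the 512-channel bottleneck).
#
#   attn = SelfAttention(512)
#   f = torch.randn(1, 512, 16, 16)
#   assert attn(f).shape == f.shape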


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)
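
# Note (added): h_sigmoid / h_swish implement hard-sigmoid and hard-swish,
# ReLU6(x + 3) / 6 and x * ReLU6(x + 3) / 6, the cheap piecewise-linear
# approximations of sigmoid and SiLU used in MobileNetV3; CoordAtt below uses
# h_swish as its activation.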


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out
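
# Hedged note (added): this is the coordinate-attention pattern. The input is
# average-pooled along H and along W separately, the two strips are encoded
# jointly by a shared 1x1 conv, and two sigmoid gates re-weight the input per
# row and per column. As constructed in this file it is always used with
# inp == oup, so the element-wise product with `identity` is shape-consistent.
#
#   ca = CoordAtt(64, 64)
#   assert ca(torch.randn(2, 64, 32, 32)).shape == (2, 64, 32, 32)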





class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1),
            nn.Tanh()
        )

class inConv(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(inConv, self).__init__(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                                     nn.InstanceNorm2d(out_channels))


class Sub_Res_down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Sub_Res_down, self).__init__()
        self.conv1 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
                                   nn.InstanceNorm2d(out_channels),
                                   nn.Mish(inplace=True),
                                   nn.Dropout(0.1))
        
        self.conv2 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
                                   nn.InstanceNorm2d(out_channels),
                                   nn.Dropout(0.1))
        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels))

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual = x

        out = self.conv1(x)

        out = self.conv2(out)

        out = self.cbam(out)

        out += self.shortcut(residual)
        out = self.relu(out)
        out = self.maxpool(out)

        return out

class Sub_Res_up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Sub_Res_up, self).__init__()
        self.conv1 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
                                   nn.InstanceNorm2d(out_channels),
                                   nn.Mish(inplace=True),
                                   nn.Dropout(0.1))
        self.conv2 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
                                   nn.InstanceNorm2d(out_channels),
                                   nn.Dropout(0.1))

        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        # Note: self.ConvT already maps the input to out_channels before the
        # residual branch runs, so this 1x1 shortcut only re-normalizes the
        # skip path rather than changing its channel count.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels))

        self.ConvT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):

        x = self.ConvT(x)

        residual = x

        out = self.conv1(x)

        out = self.conv2(out)

        out = self.cbam(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
                                   nn.GroupNorm(32,out_channels),
                                   nn.Mish(inplace=True),
                                   nn.Dropout(0.1))
        self.conv2 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
                                   nn.Dropout(0.1))
        # self.cbam = CBAM(out_channels, 8, 7)
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(32, out_channels)
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)

        out = self.conv2(out)

        out = self.cbam(out)

        out += self.shortcut(residual)
        out = self.relu(out)

        return out
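
# Note (added): GroupNorm(32, ...) requires the channel count to be divisible
# by 32, which holds for every ResNetBlock created below (64 up to 1024
# channels). The block keeps spatial size (reflection padding + 3x3 convs),
# so all down/upsampling in Gen is done by the explicit pooling and
# transposed-conv layers.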






class Gen(nn.Module):
    def __init__(self, in_channels=8, out_channels=4):
        super(Gen, self).__init__()

        self.down_1 = Sub_Res_down(2, 64)
        self.down_2 = Sub_Res_down(64, 128)
        self.up_1 = Sub_Res_up(128, 64)
        self.up_2 = Sub_Res_up(64, 32)

        self.OutConv = OutConv(32, 4)
        self.OutConv_1 = OutConv(64,4)
        self.InConv = inConv(2,32)



        # encoder
        self.conv1 = ResNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ResNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ResNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = ResNetBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # center
        self.center = ResNetBlock(512, 1024)

        # decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_decode4 = ResNetBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_decode3 = ResNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_decode2 = ResNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_decode1 = ResNetBlock(128, 64)

        self.attn = SelfAttention(512)  # diff: add self-attention module


    def forward(self, a, b, c):
        # Stage 1: the lightweight residual sub-network operates on the
        # difference image (c - a) and predicts a 4-channel intermediate map.
        x_0 = c - a
        x_1 = self.down_1(x_0)
        x_2 = self.down_2(x_1)
        x_3 = self.up_1(x_2)
        x_4 = self.up_2(x_3)
        x_r = self.InConv(x_0)
        x_5 = x_r + x_4
        x_6 = self.OutConv(x_5)

        # Stage 2: the intermediate map is concatenated with b along the
        # channel dimension to form the encoder input (8 channels with the
        # default configuration).
        y = torch.cat([x_6, b], dim=1)

        # encoder
        conv1 = self.conv1(y)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)

        attn_output = self.attn(pool4)
        
        # center
        center = self.center(attn_output)

        # decoder
        up4 = self.up4(center)
        concat4 = torch.cat([up4, conv4], dim=1)
        conv_decode4 = self.conv_decode4(concat4)
        up3 = self.up3(conv_decode4)
        concat3 = torch.cat([up3, conv3], dim=1)
        conv_decode3 = self.conv_decode3(concat3)
        up2 = self.up2(conv_decode3)
        concat2 = torch.cat([up2, conv2], dim=1)
        conv_decode2 = self.conv_decode2(concat2)
        up1 = self.up1(conv_decode2)
        concat1 = torch.cat([up1, conv1], dim=1)
        conv_decode1 = self.conv_decode1(concat1)

        # output
        output = self.OutConv_1(conv_decode1)


        return x_6, output
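
# Hedged interface note (added): Gen returns two images, the intermediate
# prediction x_6 from the difference sub-network and the final U-Net output;
# both pass through OutConv's Tanh, so values lie in [-1, 1]. Under the
# default configuration a and c are 2-channel inputs, b is a 4-channel input,
# and both returned tensors have 4 channels at the input resolution.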


class ReconstructionLoss(nn.Module):
    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, g=1.0):  # g is accepted but unused
        super(ReconstructionLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.vggloss = VGGLoss(4)

    def forward(self, prediction, target):
        # Weighted sum of a VGG perceptual term, a channel-wise cosine
        # dissimilarity term, and an MS-SSIM term.
        loss = (self.alpha * self.vggloss(prediction, target) +
                self.gamma * (1.0 - torch.mean(F.cosine_similarity(prediction, target, dim=1))) +
                self.beta * (1.0 - msssim(prediction, target, normalize=True)))
        return loss


class ResidulBlockWithSpectralNorm_1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_1, self).__init__()
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 2, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1))
        )
        self.transform = SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))

    def forward(self, inputs):
        return self.transform(inputs) + self.residual(inputs)

class ResidulBlockWithSpectralNorm_2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_2, self).__init__()
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 1, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        )
        self.transform = SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1))

    def forward(self, inputs):
        return self.transform(inputs) + self.residual(inputs)


class Discriminator(nn.Sequential):
    def __init__(self, channels):
        modules = []
        for i in range(1, (len(channels)-1)):
            modules.append(ResidulBlockWithSpectralNorm_1(channels[i - 1], channels[i]))

        modules.append(nn.Sequential(ResidulBlockWithSpectralNorm_2(channels[-2], channels[-1]),
                       ResidulBlockWithSpectralNorm_2(channels[-1], 1),
                       nn.Sigmoid()))

        super(Discriminator, self).__init__(*modules)


    def forward(self, inputs):
        prediction = super(Discriminator, self).forward(inputs)
        # return prediction.view(-1, 1).squeeze(1)

        return prediction



class MSDiscriminator(nn.Module):
    def __init__(self):
        super(MSDiscriminator, self).__init__()
        self.d1 = Discriminator((12, 64, 128, 256, 512))
        self.d2 = Discriminator((12, 128, 256, 512))
        self.d3 = Discriminator((12, 256, 512))

    def forward(self, inputs):
        # Full-, half- and quarter-resolution discriminators; their strided
        # blocks are arranged so that all three output maps end up at the
        # same spatial size and can be summed directly.
        l1 = self.d1(inputs)
        l2 = self.d2(F.interpolate(inputs, scale_factor=0.5))
        l3 = self.d3(F.interpolate(inputs, scale_factor=0.25))
        # return torch.mean(torch.stack((l1, l2, l3)))
        return l1 + l2 + l3
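
# Minimal smoke test (added; a sketch under assumed default shapes, not taken
# from the original training code): checks that the generator and the
# multi-scale discriminator produce tensors of the expected shapes. The
# spatial size is assumed to be a multiple of 16 so the four pooling stages
# divide evenly, and the 12-channel discriminator input is assumed to be a
# concatenation of three 4-channel maps, as suggested by the channel counts
# above.
if __name__ == "__main__":
    gen = Gen()
    a = torch.randn(1, 2, 128, 128)
    c = torch.randn(1, 2, 128, 128)
    b = torch.randn(1, 4, 128, 128)
    x_mid, x_out = gen(a, b, c)
    print(x_mid.shape, x_out.shape)  # both torch.Size([1, 4, 128, 128])

    disc = MSDiscriminator()
    d_in = torch.cat([x_mid, x_out, b], dim=1)  # 12 channels (an assumption about the real pairing)
    print(disc(d_in).shape)  # torch.Size([1, 1, 14, 14]) for 128x128 inputs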