import torch
import torch.nn as nn
import torch.nn.functional as F
from torchgan.layers import SpectralNorm2d
from ssim import msssim
from vggloss_1 import VGGLoss
import torchvision.models as models
import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import argparse
import lpips
from CBAM import CBAM
import lpips
class SelfAttention(nn.Module):  # diff: added self-attention module
    """Self-attention over the spatial positions of a 2-D feature map.

    Query and key are 1x1-conv projections to a reduced channel width
    (``in_dim // 8``, but never fewer than one channel); the value is a 1x1
    conv that keeps ``in_dim`` channels.  The attended features are blended
    into the input through the learnable scalar ``gamma``, which starts at
    zero, so a freshly constructed module returns its input unchanged.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Without the clamp, in_dim < 8 would yield zero-channel query/key
        # projections; for in_dim >= 8 the value is exactly in_dim // 8.
        reduced_dim = max(1, in_dim // 8)
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``.

        Args:
            x: 4-D tensor laid out as (batch, channels, dim2, dim3).
        """
        batch, channels, height, width = x.size()
        num_pos = height * width  # N: number of spatial positions

        proj_query = self.query_conv(x).view(batch, -1, num_pos).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(batch, -1, num_pos)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # each row sums to 1 over the last dim

        proj_value = self.value_conv(x).view(batch, -1, num_pos)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(batch, channels, height, width)

        # Residual blend; gamma is learned and initialised to 0.
        return self.gamma * out + x
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``nn.ReLU6``.  The activation is applied to the
            temporary ``x + 3``, so the caller's tensor is not overwritten.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the element-wise hard sigmoid of ``x``."""
        shifted = x + 3
        clipped = self.relu(shifted)
        return clipped / 6
class h_swish(nn.Module):
    """Hard swish: the input multiplied by its hard-sigmoid gate.

    Args:
        inplace: Forwarded to the internal ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        """Return ``x * h_sigmoid(x)`` element-wise."""
        gate = self.sigmoid(x)
        return x * gate
class CoordAtt(nn.Module):
    """Coordinate attention: gates a feature map with two directional maps.

    The input is average-pooled along each spatial axis separately, the two
    pooled strips are concatenated and passed through a shared 1x1 conv,
    batch norm and hard swish, then split again and turned into two sigmoid
    gates that multiply the input.

    Args:
        inp: Channels of the input feature map.
        oup: Channels produced by the two gate convolutions.  The gates are
            multiplied with the input, so they must broadcast against it.
        reduction: Divisor for the bottleneck width, which is
            ``max(8, inp // reduction)``.
    """

    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # keep dim 2, collapse dim 3
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # collapse dim 2, keep dim 3

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` scaled by the two directional attention gates."""
        _, _, h, w = x.size()

        strip_h = self.pool_h(x)  # N x C x H x 1
        strip_w = self.pool_w(x).permute(0, 1, 3, 2)  # N x C x W x 1

        # Process both strips jointly, stacked along dim 2.
        joint = torch.cat([strip_h, strip_w], dim=2)
        joint = self.act(self.bn1(self.conv1(joint)))

        feat_h, feat_w = torch.split(joint, [h, w], dim=2)
        feat_w = feat_w.permute(0, 1, 3, 2)  # back to N x C' x 1 x W

        gate_h = self.conv_h(feat_h).sigmoid()
        gate_w = self.conv_w(feat_w).sigmoid()

        return x * gate_w * gate_h
class OutConv(nn.Sequential):
    """Output head: a 1x1 convolution followed by ``tanh``.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Channels of the output; values lie in [-1, 1].
    """

    def __init__(self, in_channels, num_classes):
        projection = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        squash = nn.Tanh()
        super().__init__(projection, squash)
class inConv(nn.Sequential):
    """Input projection: a bias-free 1x1 convolution plus instance norm.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels after the projection.
    """

    def __init__(self, in_channels, out_channels):
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        normalise = nn.InstanceNorm2d(out_channels)
        super().__init__(projection, normalise)
class Sub_Res_down(nn.Module):
    """Residual block with coordinate attention, followed by 2x max-pooling.

    Two reflection-padded 3x3 convolutions (instance norm, dropout 0.1) feed
    a ``CoordAtt`` gate; the result is added to a shortcut of the input,
    passed through Mish and then down-sampled by a stride-2 max-pool.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.  When it differs from
            ``in_channels`` the shortcut is a 1x1 conv + instance norm,
            otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )
        # The attribute is called "cbam" but holds a CoordAtt module.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels != out_channels:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the attended residual features, pooled by a factor of 2."""
        features = self.conv2(self.conv1(x))
        features = self.cbam(features)
        features += self.shortcut(x)
        return self.maxpool(self.relu(features))
class Sub_Res_up(nn.Module):
    """2x transposed-conv up-sampling followed by an attended residual block.

    The input is first up-sampled (and mapped to ``out_channels``) by a
    kernel-2, stride-2 transposed convolution.  Everything after that,
    including the shortcut, works on the up-sampled tensor and therefore on
    ``out_channels`` channels.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.  When it differs from
            ``in_channels`` the shortcut is a 1x1 conv + instance norm,
            otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )
        # The attribute is called "cbam" but holds a CoordAtt module.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels != out_channels:
            # The shortcut sees the already up-sampled tensor, hence
            # out_channels on both sides.
            self.shortcut = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

        self.ConvT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Return the up-sampled, attended residual features."""
        upsampled = self.ConvT(x)
        features = self.conv2(self.conv1(upsampled))
        features = self.cbam(features)
        features += self.shortcut(upsampled)
        return self.relu(features)
class ResNetBlock(nn.Module):
    """Residual block with group norm and coordinate attention.

    Two reflection-padded 3x3 convolutions (the first followed by a
    32-group ``GroupNorm``, Mish and dropout; the second by dropout only)
    feed a ``CoordAtt`` gate.  The result is added to a shortcut of the
    input and passed through Mish.  Spatial size is preserved.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output; ``GroupNorm(32, ...)`` requires
            it to be a multiple of 32.  When it differs from ``in_channels``
            the shortcut is a 1x1 conv + group norm, otherwise the identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.Dropout(0.1),
        )
        # The attribute is called "cbam" but holds a CoordAtt module.
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels != out_channels:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(32, out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return the attended residual features for ``x``."""
        features = self.conv2(self.conv1(x))
        features = self.cbam(features)
        features += self.shortcut(x)
        return self.relu(features)
class Gen(nn.Module):
    """Two-stage generator: a small residual sub-network, then a U-Net.

    Stage 1 takes the difference ``c - a`` (the layers are built for 2 input
    channels), runs it through two down blocks and two up blocks, adds a
    1x1-projected copy of the difference and emits a 4-channel tanh map.

    Stage 2 concatenates that map with ``b`` along the channel axis and
    feeds the result to a four-level encoder/decoder of ``ResNetBlock``s
    with skip connections, a self-attention layer before the bottleneck and
    a 4-channel tanh output head.

    Args:
        in_channels: Channels of the stage-2 input, i.e. 4 (stage-1 output)
            plus the channels of ``b``.
        out_channels: Accepted but not used anywhere in this class; both
            output heads are hard-wired to 4 channels.
            NOTE(review): confirm whether this should drive ``OutConv_1``.
    """

    def __init__(self, in_channels=8, out_channels=4):
        super().__init__()
        # --- stage 1: residual sub-network on the 2-channel difference ---
        self.down_1 = Sub_Res_down(2, 64)
        self.down_2 = Sub_Res_down(64, 128)
        self.up_1 = Sub_Res_up(128, 64)
        self.up_2 = Sub_Res_up(64, 32)
        self.OutConv = OutConv(32, 4)
        self.OutConv_1 = OutConv(64, 4)
        self.InConv = inConv(2, 32)

        # --- stage 2 encoder ---
        self.conv1 = ResNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ResNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ResNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = ResNetBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # --- bottleneck ---
        self.center = ResNetBlock(512, 1024)

        # --- stage 2 decoder; each decode block sees up-sampled + skip channels ---
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_decode4 = ResNetBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_decode3 = ResNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_decode2 = ResNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_decode1 = ResNetBlock(128, 64)

        self.attn = SelfAttention(512)  # diff: added self-attention module

    def forward(self, a, b, c):
        """Run both stages.

        Args:
            a: Tensor subtracted from ``c``; same shape as ``c``.
            b: Tensor concatenated (dim 1) with the stage-1 output.
            c: Tensor from which ``a`` is subtracted.

        Returns:
            Tuple ``(stage1_output, final_output)``, each with 4 channels.

        Note:
            The skip concatenations only line up when the spatial size
            survives four 2x poolings; presumably inputs are multiples of 16.
        """
        # ----- stage 1 -----
        diff = c - a
        refined = self.down_1(diff)
        refined = self.down_2(refined)
        refined = self.up_1(refined)
        refined = self.up_2(refined)
        projected = self.InConv(diff)
        coarse = self.OutConv(projected + refined)

        # ----- stage 2: encoder, remembering each level for its skip link -----
        feat = torch.cat([coarse, b], dim=1)
        encoder_levels = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        skips = []
        for block, pool in encoder_levels:
            level_out = block(feat)
            skips.append(level_out)
            feat = pool(level_out)

        # ----- bottleneck: attention is applied to the deepest pooled map -----
        feat = self.center(self.attn(feat))

        # ----- decoder: consume the skips deepest-first -----
        decoder_levels = (
            (self.up4, self.conv_decode4),
            (self.up3, self.conv_decode3),
            (self.up2, self.conv_decode2),
            (self.up1, self.conv_decode1),
        )
        for upsample, decode in decoder_levels:
            merged = torch.cat([upsample(feat), skips.pop()], dim=1)
            feat = decode(merged)

        output = self.OutConv_1(feat)
        return coarse, output
class ReconstructionLoss(nn.Module):
    """Weighted sum of a VGG loss, a cosine-distance term and an MS-SSIM term.

    Args:
        alpha: Weight of the ``VGGLoss`` term.
        beta: Weight of the ``1 - msssim`` term.
        gamma: Weight of the ``1 - mean cosine similarity`` term.
        g: Accepted but not used by this class.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, g=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # NOTE(review): the meaning of the argument 4 is defined by VGGLoss,
        # which lives in another module - confirm there.
        self.vggloss = VGGLoss(4)

    def forward(self, prediction, target):
        """Return the combined loss between ``prediction`` and ``target``."""
        perceptual = self.vggloss(prediction, target)
        # Cosine similarity is taken along dim 1, then averaged.
        cosine_term = 1.0 - torch.mean(F.cosine_similarity(prediction, target, 1))
        structural_term = 1.0 - msssim(prediction, target, normalize=True)
        return (
            self.alpha * perceptual
            + self.gamma * cosine_term
            + self.beta * structural_term
        )
class ResidulBlockWithSpectralNorm_1(nn.Module):
    """Strided (down-sampling) residual block built from SpectralNorm2d convs.

    Both branches use a kernel-4, stride-2, padding-1 convolution, so their
    outputs have matching spatial size and can be summed.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main branch: BN -> Mish -> strided 4x4 conv -> BN -> Mish -> 1x1 conv.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(
                nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
            ),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=1)),
        )
        # Shortcut branch: a single strided conv that also changes the width.
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, inputs):
        """Return shortcut branch plus main branch."""
        shortcut = self.transform(inputs)
        main = self.residual(inputs)
        return shortcut + main
class ResidulBlockWithSpectralNorm_2(nn.Module):
    """Stride-1 residual block built from SpectralNorm2d convs.

    Both branches use a kernel-4, stride-1, padding-1 convolution, which
    shrinks each spatial dimension by one, so the two outputs can be summed.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main branch: BN -> Mish -> 4x4 conv -> BN -> Mish -> 1x1 conv.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(
                nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=1, padding=1)
            ),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, kernel_size=1)),
        )
        # Shortcut branch: a single conv that also changes the width.
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        """Return shortcut branch plus main branch."""
        shortcut = self.transform(inputs)
        main = self.residual(inputs)
        return shortcut + main
class Discriminator(nn.Sequential):
    """Stack of spectral-norm residual blocks ending in a sigmoid score map.

    Args:
        channels: Sequence of channel widths.  Every consecutive pair except
            the last becomes one strided ``ResidulBlockWithSpectralNorm_1``;
            the final pair, followed by a projection to a single channel and
            a sigmoid, forms the head.
    """

    def __init__(self, channels):
        # Down-sampling trunk: (c0 -> c1), (c1 -> c2), ..., up to channels[-2].
        stages = [
            ResidulBlockWithSpectralNorm_1(width_in, width_out)
            for width_in, width_out in zip(channels[:-2], channels[1:-1])
        ]
        # Head: two stride-1 blocks, the second mapping to one channel.
        head = nn.Sequential(
            ResidulBlockWithSpectralNorm_2(channels[-2], channels[-1]),
            ResidulBlockWithSpectralNorm_2(channels[-1], 1),
            nn.Sigmoid(),
        )
        stages.append(head)
        super().__init__(*stages)

    def forward(self, inputs):
        """Return the score map as produced by the stack (not flattened)."""
        return super().forward(inputs)
class MSDiscriminator(nn.Module):
    """Multi-scale discriminator: three ``Discriminator``s on 1x, 0.5x, 0.25x.

    All three sub-discriminators expect 12 input channels.  The ones that
    see smaller inputs have fewer strided stages.
    """

    def __init__(self):
        super().__init__()
        self.d1 = Discriminator((12, 64, 128, 256, 512))
        self.d2 = Discriminator((12, 128, 256, 512))
        self.d3 = Discriminator((12, 256, 512))

    def forward(self, inputs):
        """Return the element-wise sum of the three per-scale score maps.

        The maps are added directly, so they must be shape-compatible.
        """
        full_scale = self.d1(inputs)
        half_scale = self.d2(F.interpolate(inputs, scale_factor=0.5))
        quarter_scale = self.d3(F.interpolate(inputs, scale_factor=0.25))
        # NOTE: the scores are summed, not averaged.
        combined = full_scale + half_scale + quarter_scale
        return combined
model_MTS2ONet
于 2025-01-16 20:46:27 首次发布
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

被折叠的 条评论
为什么被折叠?



