# model_MTS2ONet

# 部署运行你感兴趣的模型镜像
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchgan.layers import SpectralNorm2d

from ssim import msssim

from vggloss_1 import VGGLoss
import torchvision.models as models

import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import argparse
import lpips

from CBAM import CBAM
import lpips


class SelfAttention(nn.Module):  # diff: added self-attention module
    """Spatial self-attention with a learnable, zero-initialised residual gate.

    Every spatial position of the feature map attends to every other position.
    The attended features are scaled by ``gamma`` (a scalar parameter that
    starts at 0, so a freshly built block is an identity map) and added back
    to the input.

    Args:
        in_dim: number of input (and output) channels. The query/key
            projections use ``in_dim // 8`` channels.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        reduced = in_dim // 8
        # 1x1 projections. Creation order is unchanged so seeded init matches.
        self.query_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Residual gate; zero at init.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Attend over all spatial positions of a 4-D input; output has the same shape."""
        batch, channels, dim2, dim3 = x.size()
        positions = dim2 * dim3

        # queries: B x N x C'   keys: B x C' x N
        queries = self.query_conv(x).view(batch, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch, -1, positions)
        # Pairwise affinities, softmax-normalised over the last axis: B x N x N
        weights = self.softmax(torch.bmm(queries, keys))
        # values: B x C x N, aggregated with the transposed attention map
        values = self.value_conv(x).view(batch, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, dim2, dim3)

        return self.gamma * attended + x


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear sigmoid.

    Args:
        inplace: forwarded to the internal ``nn.ReLU6``. The clamp runs on the
            temporary ``x + 3``, so the caller's tensor is never modified.
    """

    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6


class h_swish(nn.Module):
    """Hard swish activation: ``x * h_sigmoid(x)``.

    Args:
        inplace: forwarded to the wrapped ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate


class CoordAtt(nn.Module):
    """Coordinate attention: gates the input with separate height- and width-wise maps.

    The input is average-pooled along each spatial axis. The two pooled
    descriptors are stacked and encoded jointly by a shared 1x1 conv,
    BatchNorm and h_swish. They are then split again and projected to a
    per-row gate (``conv_h``) and a per-column gate (``conv_w``), both
    sigmoids, which multiply the input.

    Args:
        inp: input channels.
        oup: gate channels; since the gates multiply the input directly this
            must equal ``inp`` (or be 1 to broadcast).
        reduction: channel reduction of the shared encoder; its width is
            ``max(8, inp // reduction)``.
    """

    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        # (None, 1) keeps H and squeezes W; (1, None) keeps W and squeezes H.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        hidden = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, h, w = x.size()

        # Height descriptor (N, C, H, 1); width descriptor (N, C, 1, W) laid out as (N, C, W, 1).
        pooled_h = self.pool_h(x)
        pooled_w = self.pool_w(x).permute(0, 1, 3, 2)

        # Encode both descriptors together, stacked along dim 2.
        joint = torch.cat([pooled_h, pooled_w], dim=2)
        joint = self.act(self.bn1(self.conv1(joint)))

        enc_h, enc_w = torch.split(joint, [h, w], dim=2)
        gate_h = self.conv_h(enc_h).sigmoid()
        gate_w = self.conv_w(enc_w.permute(0, 1, 3, 2)).sigmoid()

        return x * gate_w * gate_h





class OutConv(nn.Sequential):
    """Output head: 1x1 convolution followed by ``tanh`` (outputs lie in [-1, 1]).

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output channels.
    """

    def __init__(self, in_channels, num_classes):
        projection = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        squash = nn.Tanh()
        super(OutConv, self).__init__(projection, squash)

class inConv(nn.Sequential):
    """Input projection: bias-free 1x1 convolution followed by ``InstanceNorm2d``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced.
    """

    def __init__(self, in_channels, out_channels):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.InstanceNorm2d(out_channels),
        ]
        super(inConv, self).__init__(*layers)


class Sub_Res_down(nn.Module):
    """Residual encoder block that halves the spatial resolution.

    Main path:
    - two reflection-padded 3x3 convs, each followed by InstanceNorm and
      Dropout(0.1), with Mish after the first;
    - then CoordAtt.

    The main path is added to the shortcut, passed through Mish and then
    2x2 max-pooled. The shortcut is the identity when
    ``in_channels == out_channels`` and a 1x1 conv + InstanceNorm otherwise.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(Sub_Res_down, self).__init__()
        # Attribute names and Sequential indices form the state_dict keys.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )
        # Attention stage (an earlier variant used CBAM(out_channels, 8, 7)).
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.cbam(self.conv2(self.conv1(x)))
        features += self.shortcut(x)
        return self.maxpool(self.relu(features))

class Sub_Res_up(nn.Module):
    """Residual decoder block that doubles the spatial resolution.

    The input is first upsampled by a 2x2, stride-2 ``ConvTranspose2d``
    (``in_channels -> out_channels``). Everything after that works on
    ``out_channels``:
    - two reflection-padded 3x3 convs, each followed by InstanceNorm and
      Dropout(0.1), with Mish after the first;
    - then CoordAtt;
    - then the shortcut is added and Mish is applied.

    The shortcut sees the upsampled tensor. It is the identity when
    ``in_channels == out_channels`` and otherwise a 1x1 conv
    (``out_channels -> out_channels``) + InstanceNorm.

    Args:
        in_channels: channels of the input before upsampling.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(Sub_Res_up, self).__init__()
        # Attribute names and Sequential indices form the state_dict keys.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.1),
        )

        # Attention stage (an earlier variant used CBAM(out_channels, 8, 7)).
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.InstanceNorm2d(out_channels),
            )

        self.ConvT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        upsampled = self.ConvT(x)
        features = self.cbam(self.conv2(self.conv1(upsampled)))
        features += self.shortcut(upsampled)
        return self.relu(features)

class ResNetBlock(nn.Module):
    """Resolution-preserving residual block used by the U-Net in ``Gen``.

    Main path:
    - reflection-padded 3x3 conv, GroupNorm(32), Mish, Dropout(0.1);
    - reflection-padded 3x3 conv, Dropout(0.1);
    - CoordAtt.

    The shortcut is then added and Mish is applied. The shortcut is the
    identity when ``in_channels == out_channels`` and a 1x1 conv +
    GroupNorm(32) otherwise. Because of GroupNorm(32, ...), ``out_channels``
    must be divisible by 32.

    NOTE(review): unlike ``Sub_Res_down`` / ``Sub_Res_up``, the second conv
    has no normalisation layer. It is left as-is because adding one would
    change the state_dict.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        # Attribute names and Sequential indices form the state_dict keys.
        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.Mish(inplace=True),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, bias=False),
            nn.Dropout(0.1),
        )
        # Attention stage (an earlier variant used CBAM(out_channels, 8, 7)).
        self.cbam = CoordAtt(out_channels, out_channels)
        self.relu = nn.Mish(inplace=True)

        if in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.GroupNorm(32, out_channels),
            )

    def forward(self, x):
        features = self.cbam(self.conv2(self.conv1(x)))
        features += self.shortcut(x)
        return self.relu(features)






class Gen(nn.Module):
    """Two-stage generator.

    Stage 1 (difference branch): ``c - a`` goes through two down-sampling
    residual blocks (2 -> 64 -> 128 channels, 1/4 resolution) and back up
    through two up-sampling blocks (128 -> 64 -> 32 channels). The result is
    added to a 1x1 projection of the same difference (``InConv``) and mapped
    to 4 channels in [-1, 1] by ``OutConv``.

    Stage 2 (U-Net): the stage-1 output is concatenated with ``b`` along
    dim 1 and passed through a 4-level ResNet U-Net. A ``SelfAttention``
    layer is applied to the deepest pooled features before the bottleneck.
    The network ends in ``OutConv_1`` (tanh head).

    Args:
        in_channels: channels of the U-Net input, i.e. 4 (the stage-1 output)
            plus the channel count of ``b``.
        out_channels: channels of the final output. This argument used to be
            ignored, with the head hard-coded to 4. The default keeps that
            behaviour.
    """

    def __init__(self, in_channels=8, out_channels=4):
        super(Gen, self).__init__()

        # ---- stage 1 (fixed widths: 2-channel difference in, 4 channels out) ----
        self.down_1 = Sub_Res_down(2, 64)
        self.down_2 = Sub_Res_down(64, 128)
        self.up_1 = Sub_Res_up(128, 64)
        self.up_2 = Sub_Res_up(64, 32)

        self.OutConv = OutConv(32, 4)
        # U-Net output head; honours ``out_channels`` (previously hard-coded to 4).
        self.OutConv_1 = OutConv(64, out_channels)
        self.InConv = inConv(2, 32)

        # ---- stage 2: U-Net encoder ----
        self.conv1 = ResNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = ResNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = ResNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = ResNetBlock(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.center = ResNetBlock(512, 1024)

        # decoder: upsample, then fuse with the matching encoder skip
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv_decode4 = ResNetBlock(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_decode3 = ResNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_decode2 = ResNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_decode1 = ResNetBlock(128, 64)

        self.attn = SelfAttention(512)  # diff: added self-attention module

    def forward(self, a, b, c):
        """Run both stages.

        Args:
            a, c: tensors whose difference ``c - a`` must have 2 channels.
            b: tensor concatenated with the 4-channel stage-1 output; it must
                have ``in_channels - 4`` channels.

        Height and width of the inputs must be divisible by 16. Otherwise
        the max-pool / transposed-conv pairs produce tensors that cannot be
        added or concatenated with their skips.

        Returns:
            ``(x_6, output)``: the 4-channel stage-1 result and the
            ``out_channels``-channel U-Net result, both tanh outputs.
        """
        # ---- stage 1: difference branch ----
        diff = c - a
        coarse = self.up_2(self.up_1(self.down_2(self.down_1(diff))))
        x_6 = self.OutConv(self.InConv(diff) + coarse)

        # ---- stage 2: U-Net over [stage-1 output, b] ----
        y = torch.cat([x_6, b], dim=1)

        conv1 = self.conv1(y)
        conv2 = self.conv2(self.pool1(conv1))
        conv3 = self.conv3(self.pool2(conv2))
        conv4 = self.conv4(self.pool3(conv3))

        # Self-attention on the deepest pooled features, then the bottleneck.
        center = self.center(self.attn(self.pool4(conv4)))

        dec4 = self.conv_decode4(torch.cat([self.up4(center), conv4], dim=1))
        dec3 = self.conv_decode3(torch.cat([self.up3(dec4), conv3], dim=1))
        dec2 = self.conv_decode2(torch.cat([self.up2(dec3), conv2], dim=1))
        dec1 = self.conv_decode1(torch.cat([self.up1(dec2), conv1], dim=1))

        output = self.OutConv_1(dec1)
        return x_6, output


class ReconstructionLoss(nn.Module):
    """Weighted sum of a perceptual, a cosine-similarity and an MS-SSIM loss.

    loss = alpha * VGGLoss(prediction, target)
         + gamma * (1 - mean(cosine_similarity(prediction, target, dim=1)))
         + beta  * (1 - msssim(prediction, target, normalize=True))

    Args:
        alpha: weight of the VGG perceptual term.
        beta: weight of the MS-SSIM term.
        gamma: weight of the cosine-similarity term.
        g: accepted for backward compatibility; not used.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, g=1.0):
        super(ReconstructionLoss, self).__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        # NOTE(review): VGGLoss is built with 4, presumably the input channel count; confirm in vggloss_1.
        self.vggloss = VGGLoss(4)

    def forward(self, prediction, target):
        perceptual = self.vggloss(prediction, target)
        # Cosine similarity along dim 1 (the channel axis for NCHW inputs), averaged.
        angular = 1.0 - torch.mean(F.cosine_similarity(prediction, target, 1))
        structural = 1.0 - msssim(prediction, target, normalize=True)
        return self.alpha * perceptual + self.gamma * angular + self.beta * structural


class ResidulBlockWithSpectralNorm_1(nn.Module):
    """Stride-2 residual discriminator block with spectrally normalised convs.

    The output is ``transform(x) + residual(x)``, where:
    - ``transform`` is an SN 4x4 conv, stride 2, padding 1
      (``in_channels -> out_channels``);
    - ``residual`` is BN -> Mish -> SN 4x4 conv, stride 2, padding 1
      (``in -> in``), then BN -> Mish -> SN 1x1 conv (``in -> out``).

    Both branches halve even heights and widths.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_1, self).__init__()
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 2, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        ]
        self.residual = nn.Sequential(*layers)
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, inputs):
        shortcut = self.transform(inputs)
        return shortcut + self.residual(inputs)

class ResidulBlockWithSpectralNorm_2(nn.Module):
    """Stride-1 residual discriminator block with spectrally normalised convs.

    The output is ``transform(x) + residual(x)``, where:
    - ``transform`` is an SN 4x4 conv, stride 1, padding 1
      (``in_channels -> out_channels``);
    - ``residual`` is BN -> Mish -> SN 4x4 conv, stride 1, padding 1
      (``in -> in``), then BN -> Mish -> SN 1x1 conv (``in -> out``).

    With a 4x4 kernel and padding 1, each branch shrinks height and width
    by 1.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidulBlockWithSpectralNorm_2, self).__init__()
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, in_channels, 4, 1, 1)),
            nn.BatchNorm2d(in_channels),
            nn.Mish(),
            SpectralNorm2d(nn.Conv2d(in_channels, out_channels, 1)),
        ]
        self.residual = nn.Sequential(*layers)
        self.transform = SpectralNorm2d(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, inputs):
        shortcut = self.transform(inputs)
        return shortcut + self.residual(inputs)


class Discriminator(nn.Sequential):
    """Discriminator built from spectrally normalised residual blocks.

    For ``channels = (c0, c1, ..., cK)`` it stacks:
    - one stride-2 ``ResidulBlockWithSpectralNorm_1`` for each consecutive
      pair ``(c0, c1) ... (c[K-2], c[K-1])``;
    - a head made of two stride-1 ``ResidulBlockWithSpectralNorm_2`` blocks
      (``c[K-1] -> cK`` and ``cK -> 1``) followed by a sigmoid.

    The output is a single-channel spatial map of scores in (0, 1).

    Args:
        channels: sequence of channel counts; ``channels[0]`` is the input
            channel count.
    """

    def __init__(self, channels):
        blocks = [
            ResidulBlockWithSpectralNorm_1(c_in, c_out)
            for c_in, c_out in zip(channels[:-2], channels[1:-1])
        ]
        head = nn.Sequential(
            ResidulBlockWithSpectralNorm_2(channels[-2], channels[-1]),
            ResidulBlockWithSpectralNorm_2(channels[-1], 1),
            nn.Sigmoid(),
        )
        blocks.append(head)
        super(Discriminator, self).__init__(*blocks)

    def forward(self, inputs):
        """Return the sigmoid score map for ``inputs`` (not flattened)."""
        return super(Discriminator, self).forward(inputs)



class MSDiscriminator(nn.Module):
    """Three-scale discriminator over 12-channel inputs.

    ``d1`` sees the full-resolution input. ``d2`` and ``d3`` see copies
    down-scaled by 0.5x and 0.25x with ``F.interpolate``'s default
    (nearest) mode. The three discriminators use 3, 2 and 1 stride-2 blocks
    respectively. For heights and widths divisible by 8 their score maps
    therefore share one size, and the maps are summed element-wise.

    NOTE(review): each map is a sigmoid output, so the sum lies in (0, 3),
    not (0, 1). A commented-out alternative averaged them instead. Confirm
    the training loss expects this range.
    """

    def __init__(self):
        super(MSDiscriminator, self).__init__()
        self.d1 = Discriminator((12, 64, 128, 256, 512))
        self.d2 = Discriminator((12, 128, 256, 512))
        self.d3 = Discriminator((12, 256, 512))

    def forward(self, inputs):
        full = self.d1(inputs)
        half = self.d2(F.interpolate(inputs, scale_factor=0.5))
        quarter = self.d3(F.interpolate(inputs, scale_factor=0.25))
        return full + half + quarter

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

无法解析导入 `Onet_Swin_config_20230408` 通常有以下几种可能原因及对应的解决办法:

### 模块或文件不存在

- **检查文件是否存在**:确认 `Onet_Swin_config_20230408` 对应的文件是否存在于正确的路径下。如果它是一个 Python 模块,应该有对应的 `.py` 文件。例如,如果代码中使用 `import Onet_Swin_config_20230408`,则需要有 `Onet_Swin_config_20230408.py` 文件。

```python
import os

file_path = 'path/to/Onet_Swin_config_20230408.py'
if os.path.exists(file_path):
    print("文件存在")
else:
    print("文件不存在,请检查路径")
```

- **检查文件名和大小写**:Python 对文件名的大小写是敏感的,确保文件名的大小写与导入语句中的一致。

### 路径问题

- **检查 Python 路径**:如果 `Onet_Swin_config_20230408` 所在的目录不在 Python 的搜索路径中,Python 将无法找到它。可以通过以下方式查看 Python 的搜索路径:

```python
import sys
print(sys.path)
```

- **添加路径**:如果文件所在的目录不在上述路径中,可以临时添加该路径:

```python
import sys
sys.path.append('path/to/directory_containing_Onet_Swin_config_20230408')
import Onet_Swin_config_20230408
```

也可以将该目录添加到环境变量 `PYTHONPATH` 中,这样每次启动 Python 时都会自动包含该路径。

### 包结构问题

- **检查包结构**:如果 `Onet_Swin_config_20230408` 是包的一部分,需要确保包的结构正确。包目录下应该有 `__init__.py` 文件(Python 3.3 及以后版本中,该文件不是必需的,但为了兼容性,建议保留)。例如,如果 `Onet_Swin_config_20230408` 位于 `my_package` 包中,目录结构应该如下:

```
my_package/
    __init__.py
    Onet_Swin_config_20230408.py
```

导入时可以使用 `from my_package import Onet_Swin_config_20230408`。

### 依赖问题

- **检查依赖库**:`Onet_Swin_config_20230408` 可能依赖于其他库,如果这些依赖库没有正确安装,可能会导致导入失败。检查该模块的文档,确保所有依赖库都已正确安装。

### 缓存问题

- **清除 `.pyc` 缓存**:有时候,Python 的字节码缓存文件(`.pyc`)可能会导致问题。可以尝试删除项目目录下的所有 `.pyc` 文件,然后重新运行代码。在项目根目录下可以使用以下命令删除 `.pyc` 文件:

```bash
find . -name "*.pyc" -delete
```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值