Swin-style spatial windows

Original:

class AesSA(nn.Module):
    def __init__(self, in_planes):
        super(AesSA, self).__init__()

        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        size = feat.size()
        mean, std = calc_mean_std(feat)
        normalized_feat = (feat - mean.expand(size)) / std.expand(size)
        return normalized_feat

    def forward(self, content, style):

        content_key = self.mean_variance_norm(content)

        F = self.f(content_key)
        G = self.g(style)
        H = self.h(style)
        b, _, h_g, w_g = G.size()
        # .contiguous() ensures the tensor is stored contiguously in memory
        G = G.view(b, -1, w_g * h_g).contiguous()
        # .transpose(1, 2) swaps dims 1 and 2 of the reshaped tensor
        style_flat = H.view(b, -1, w_g * h_g).transpose(1, 2).contiguous()
        b, _, h, w = F.size()
        F = F.view(b, -1, w * h).permute(0, 2, 1)
        S = torch.bmm(F, G)  # batched matrix product of content key and style key
        S = self.sm(S)  # softmax turns each row into a non-negative distribution summing to 1
        # mean: b, n_c, c
        mean = torch.bmm(S, style_flat)
        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))
        mean = mean.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

        return std * self.mean_variance_norm(content) + mean
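
For reference, calc_mean_std is not shown in this post; below is a quick smoke test, assuming the usual AdaIN-style definition (per-sample, per-channel spatial statistics with an epsilon). This is a sketch, not necessarily the project's exact helper:

import torch
import torch.nn as nn

def calc_mean_std(feat, eps=1e-5):
    # Assumed AdaIN-style helper: per-sample, per-channel spatial mean and std
    b, c = feat.size()[:2]
    flat = feat.view(b, c, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    mean = flat.mean(dim=2).view(b, c, 1, 1)
    return mean, std

# Smoke test: the output must keep the content feature's shape
attn = AesSA(in_planes=512)
content = torch.randn(1, 512, 32, 32)
style = torch.randn(1, 512, 32, 32)
assert attn(content, style).shape == content.shape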

Mod 1: this version ends up reconstructing the style image's content

class AesSA(nn.Module):
    def __init__(self, in_planes):
        super(AesSA, self).__init__()

        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

        self.window_size = 4

    def mean_variance_norm(self, feat):
        size = feat.size()
        mean, std = calc_mean_std(feat)
        normalized_feat = (feat - mean.expand(size)) / std.expand(size)
        return normalized_feat

    def forward(self, content, style):
        def window_partition(x, window_size):
            B, C, H, W = x.shape
            x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
            windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
            return windows

        def window_reverse(windows, window_size, hw):
            B = int(windows.shape[0] / (hw[0] * hw[1] / window_size / window_size))
            x = windows.view(B, hw[0] // window_size, hw[1] // window_size, -1, window_size, window_size)
            x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, hw[0], hw[1])
            return x

        B, C, H_c, W_c = content.shape
        content_key = self.mean_variance_norm(content)
        # Keep raw content windows (c_win) for the final re-normalization,
        # and normalized windows (content_windows) for the attention key
        c_win = window_partition(content, self.window_size)
        content_windows = window_partition(content_key, self.window_size)
        style_windows = window_partition(style, self.window_size)

        F = self.f(content_windows)
        G = self.g(style_windows)
        H = self.h(style_windows)
        b, _, h_g, w_g = G.size()
        # .contiguous() ensures the tensor is stored contiguously in memory
        G = G.view(b, -1, w_g * h_g).contiguous()
        # .transpose(1, 2) swaps dims 1 and 2 of the reshaped tensor
        style_flat = H.view(b, -1, w_g * h_g).transpose(1, 2).contiguous()
        b, _, h, w = F.size()
        F = F.view(b, -1, w * h).permute(0, 2, 1)
        S = torch.bmm(F, G)  # batched matrix product of content key and style key
        S = self.sm(S)  # softmax over each row
        # mean: b, n_c, c
        mean = torch.bmm(S, style_flat)
        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))
        mean = mean.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

        O = std * self.mean_variance_norm(c_win) + mean

        out = window_reverse(O, self.window_size, (H_c, W_c))

        return out
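
These two helpers are the load-bearing part; a self-contained round-trip check (same definitions, lifted out of forward) confirms window_reverse exactly inverts window_partition, so any artifacts must come from the attention math, not the tiling:

import torch

def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    return x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)

def window_reverse(windows, window_size, hw):
    B = int(windows.shape[0] / (hw[0] * hw[1] / window_size / window_size))
    x = windows.view(B, hw[0] // window_size, hw[1] // window_size, -1, window_size, window_size)
    return x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, hw[0], hw[1])

x = torch.randn(2, 16, 32, 32)
assert torch.equal(window_reverse(window_partition(x, 4), 4, (32, 32)), x)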

Mod 2: produces a grid-pattern result

class AesSA(nn.Module):
    def __init__(self, in_planes, window_size=8, shift_size=4):
        super(AesSA, self).__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        size = feat.size()
        mean, std = calc_mean_std(feat)
        normalized_feat = (feat - mean.expand(size)) / std.expand(size)
        return normalized_feat

    def window_partition(self, feat):
        b, c, h, w = feat.shape
        feat_windows = feat.view(b, c, h // self.window_size, self.window_size,
                                  w // self.window_size, self.window_size)
        feat_windows = feat_windows.permute(0, 2, 4, 1, 5, 3).contiguous()
        feat_windows = feat_windows.view(-1, c, self.window_size, self.window_size)
        return feat_windows

    def window_reverse(self, windows, b, h, w):
        windows = windows.view(-1, self.window_size, self.window_size,
                                h // self.window_size, w // self.window_size)
        windows = windows.permute(0, 3, 4, 1, 2).contiguous().view(b, -1, h, w)
        return windows

    def forward(self, content, style):
        content_key = self.mean_variance_norm(content)

        F = self.f(content_key)
        G = self.g(style)
        H = self.h(style)

        b, _, h, w = F.size()
        F_windows = self.window_partition(F)
        G_windows = self.window_partition(G)
        # BUG (source of the grid pattern): F_windows is (num_windows, c, ws, ws), so this
        # view folds the channel axis into the batch axis and attends across windows
        # rather than across positions inside each window; see the shape trace after this class
        F_windows = F_windows.view(-1, self.window_size * self.window_size, h // self.window_size * w // self.window_size).permute(0, 2, 1)
        G_windows = G_windows.view(-1, self.window_size * self.window_size, h // self.window_size * w // self.window_size).contiguous()
        S = torch.bmm(F_windows, G_windows)
        S = self.sm(S)

        H_windows = self.window_partition(H)
        style_flat = H_windows.view(-1, self.window_size * self.window_size, h // self.window_size * w // self.window_size).transpose(1, 2).contiguous()

        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))

        mean = self.window_reverse(mean, b, h, w)
        std = self.window_reverse(std, b, h, w)

        return std * self.mean_variance_norm(content) + mean
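
The grid pattern in Mod 2 traces back to the flatten step: F_windows has shape (num_windows, c, window_size, window_size), so view(-1, ws², (h/ws)·(w/ws)) reads the same bytes as (b·c, ws², num_windows_per_image), folding channels into the batch axis and computing attention across windows instead of across positions within a window. A minimal shape trace (a sketch, assuming a single 512-channel 64x64 feature map and window_size=8):

import torch

fw = torch.randn(64, 512, 8, 8)                   # 64 windows of a (1, 512, 64, 64) feature map
bad = fw.view(-1, 8 * 8, 8 * 8)                   # (512, 64, 64): channels folded into batch, window index used as positions
good = fw.view(64, 512, 8 * 8).permute(0, 2, 1)   # (64, 64, 512): positions x channels, as Mod 3 does it
print(bad.shape, good.shape)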

Mod 3:

class AesSA(nn.Module):
    def __init__(self, in_planes, window_size=8, shift_size=4):
        super(AesSA, self).__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        size = feat.size()
        mean, std = calc_mean_std(feat)
        normalized_feat = (feat - mean.expand(size)) / std.expand(size)
        return normalized_feat

    def window_partition(self, feat):
        b, c, h, w = feat.shape
        feat_windows = feat.view(b, c, h // self.window_size, self.window_size,
                                  w // self.window_size, self.window_size)
        feat_windows = feat_windows.permute(0, 2, 4, 1, 5, 3).contiguous()
        feat_windows = feat_windows.view(-1, c, self.window_size, self.window_size)
        return feat_windows

    def window_reverse(self, windows, b, h, w):
        windows = windows.view(-1, self.window_size, self.window_size,
                                h // self.window_size, w // self.window_size)
        windows = windows.permute(0, 3, 4, 1, 2).contiguous().view(b, -1, h, w)
        return windows

    def forward(self, content, style):
        content_key = self.mean_variance_norm(content)

        F = self.f(content_key)
        G = self.g(style)
        H = self.h(style)

        b, _, h, w = F.size()
        F_windows = self.window_partition(F)
        G_windows = self.window_partition(G)
        b_w, _, h_w, w_w = F_windows.size()
        F_windows = F_windows.view(b_w, -1, h_w * w_w).permute(0, 2, 1)  # (num_windows, ws*ws, c)
        G_windows = G_windows.view(b_w, -1, h_w * w_w).contiguous()      # (num_windows, c, ws*ws)
        S = torch.bmm(F_windows, G_windows)
        S = self.sm(S)

        H_windows = self.window_partition(H)
        style_flat = H_windows.view(b_w, -1, h_w * w_w).transpose(1, 2).contiguous()

        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))

        # NOTE: this window_reverse reinterprets the contiguous (num_windows, ws*ws, c)
        # layout of mean/std as spatial dims, scrambling channels; see the fix sketch below
        mean = self.window_reverse(mean, b, h, w)
        std = self.window_reverse(std, b, h, w)

        return std * self.mean_variance_norm(content) + mean
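
Even with the flatten fixed, this window_reverse still does not invert the window_partition above: mean arrives contiguous as (num_windows, position, channel), but the view inside window_reverse reads those bytes as if the spatial dims came first. A hedged fix sketch (window_reverse_fixed is my name, not the original's): restore the (num_windows, c, ws, ws) layout first, then undo the partition's exact permute:

    def window_reverse_fixed(self, windows, b, h, w):
        # Inverse of the window_partition above, whose permute(0, 2, 4, 1, 5, 3)
        # leaves windows laid out as (b, h/ws, w/ws, c, ws_w, ws_h)
        ws = self.window_size
        c = windows.shape[1]
        x = windows.view(b, h // ws, w // ws, c, ws, ws)
        x = x.permute(0, 3, 1, 5, 2, 4).contiguous()  # back to (b, c, h/ws, ws_h, w/ws, ws_w)
        return x.view(b, c, h, w)

    # In forward, restore the per-window layout before reversing, e.g.:
    # c = mean.size(-1)
    # mean = mean.permute(0, 2, 1).contiguous().view(-1, c, self.window_size, self.window_size)
    # mean = self.window_reverse_fixed(mean, b, h, w)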

Mod 4: the output looks like the style image; at the time I suspected an AdaAttN-style issue

import torch
import torch.nn as nn
import torch.nn.functional as F

class AesSA(nn.Module):
    def __init__(self, in_planes, window_size=7):
        super(AesSA, self).__init__()
        self.window_size = window_size

        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)
    
    def mean_variance_norm(self, feat):
        size = feat.size()
        mean, std = calc_mean_std(feat)
        normalized_feat = (feat - mean.expand(size)) / std.expand(size)
        return normalized_feat
    
    def window_partition(self, x, window_size):
        B, C, H, W = x.shape
        x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
        windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        windows = windows.view(-1, C, window_size, window_size)
        return windows

    def window_reverse(self, windows, window_size, H, W):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, -1, window_size, window_size)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        x = x.view(B, -1, H, W)
        return x
    
    def forward(self, content, style):
        B, C, H, W = content.shape

        # Normalize content key
        content_key = self.mean_variance_norm(content)
        
        # Apply conv layers (renamed: the originals reused F and H, shadowing the
        # functional alias F and, worse, the spatial height H used by window_reverse)
        F_feat = self.f(content_key)
        G_feat = self.g(style)
        H_feat = self.h(style)

        # Partition into windows
        F_windows = self.window_partition(F_feat, self.window_size)
        G_windows = self.window_partition(G_feat, self.window_size)
        H_windows = self.window_partition(H_feat, self.window_size)
        
        # Compute attention within windows
        F_windows_flat = F_windows.view(-1, C, self.window_size * self.window_size).permute(0, 2, 1)
        G_windows_flat = G_windows.view(-1, C, self.window_size * self.window_size)
        S = torch.bmm(F_windows_flat, G_windows_flat)
        S = self.sm(S)
        
        H_windows_flat = H_windows.view(-1, C, self.window_size * self.window_size).permute(0, 2, 1)
        mean = torch.bmm(S, H_windows_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, H_windows_flat ** 2) - mean ** 2))
        
        mean = mean.view(-1, self.window_size, self.window_size, C).permute(0, 3, 1, 2).contiguous()
        std = std.view(-1, self.window_size, self.window_size, C).permute(0, 3, 1, 2).contiguous()
        
        # Reverse windows to original shape
        mean = self.window_reverse(mean, self.window_size, H, W)
        std = self.window_reverse(std, self.window_size, H, W)
        
        return std * self.mean_variance_norm(content) + mean
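
window_size=7 also requires H and W to be multiples of 7, which typical VGG relu4_1 feature maps (32x32 or 64x64) are not; a hedged padding sketch (pad_to_multiple is a hypothetical helper, using the torch.nn.functional alias F imported above):

def pad_to_multiple(x, window_size):
    # Reflect-pad H and W up to the next multiple of window_size
    _, _, h, w = x.shape
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    return F.pad(x, (0, pad_w, 0, pad_h), mode='reflect'), (h, w)

# Usage: pad content/style before the forward pass, crop the output back:
# content_p, (h, w) = pad_to_multiple(content, 7)
# out = out[:, :, :h, :w]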

Mod 5:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionUnit(nn.Module):
    def __init__(self, channels, window_size=7):
        super(AttentionUnit, self).__init__()
        self.window_size = window_size

        self.relu6 = nn.ReLU6()
        self.f = nn.Conv2d(channels, channels // 2, (1, 1))
        self.g = nn.Conv2d(channels, channels // 2, (1, 1))
        self.h = nn.Conv2d(channels, channels // 2, (1, 1))

        self.out_conv = nn.Conv2d(channels // 2, channels, (1, 1))
        self.softmax = nn.Softmax(dim=-1)

        self.fuse = FuseUnit(channels)
        self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)
        self.pad = nn.ReflectionPad2d((1, 1, 1, 1))

    def mean_variance_norm(self, feat):
        size = feat.size()
        mean = feat.mean(dim=(2, 3), keepdim=True)
        var = feat.var(dim=(2, 3), unbiased=False, keepdim=True)
        return (feat - mean) / torch.sqrt(var + 1e-5)

    def window_partition(self, x, window_size):
        B, C, H, W = x.shape
        x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
        windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        windows = windows.view(-1, C, window_size, window_size)
        return windows

    def window_reverse(self, windows, window_size, H, W):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, -1, window_size, window_size)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        x = x.view(B, -1, H, W)
        return x

    def forward(self, Fc, Fs):
        B, C, H, W = Fc.shape

        # Normalize features
        f_Fc = self.relu6(self.f(self.mean_variance_norm(Fc)))
        g_Fs = self.relu6(self.g(self.mean_variance_norm(Fs)))
        h_Fs = self.relu6(self.h(self.mean_variance_norm(Fs)))

        # Partition into windows
        f_Fc_windows = self.window_partition(f_Fc, self.window_size)
        g_Fs_windows = self.window_partition(g_Fs, self.window_size)
        h_Fs_windows = self.window_partition(h_Fs, self.window_size)

        # Compute attention within windows
        f_Fc_windows_flat = f_Fc_windows.view(-1, C // 2, self.window_size * self.window_size).permute(0, 2, 1)
        g_Fs_windows_flat = g_Fs_windows.view(-1, C // 2, self.window_size * self.window_size)
        Attention = self.softmax(torch.bmm(f_Fc_windows_flat, g_Fs_windows_flat))

        # Aggregate style values with the attention map: (num_windows, ws*ws, C/2)
        h_Fs_windows_flat = h_Fs_windows.view(-1, C // 2, self.window_size * self.window_size).permute(0, 2, 1)
        Fcs = torch.bmm(Attention, h_Fs_windows_flat)
        # Restore (num_windows, C/2, ws, ws), then merge the windows back to (B, C/2, H, W)
        Fcs = Fcs.permute(0, 2, 1).contiguous().view(-1, C // 2, self.window_size, self.window_size)
        Fcs = self.window_reverse(Fcs, self.window_size, H, W)

        # Apply convolutions
        Fcs = self.relu6(self.out_conv(Fcs))
        Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
        Fcs = self.fuse(Fc, Fcs)

        return Fcs
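
FuseUnit is not defined in this snippet; it appears to come from an external codebase (PAMA uses a module with this name). A minimal placeholder, sufficient only for smoke-testing the class above, not the real fusion unit:

class FuseUnit(nn.Module):
    # Placeholder: fuse content and stylized features with a learned 1x1 projection
    def __init__(self, channels):
        super(FuseUnit, self).__init__()
        self.proj = nn.Conv2d(2 * channels, channels, (1, 1))

    def forward(self, Fc, Fcs):
        return self.proj(torch.cat([Fc, Fcs], dim=1))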

# Save intermediate results (training-loop fragment; Path, save_image and logging are imported elsewhere)
output_dir = Path('samples')
output_dir.mkdir(exist_ok=True, parents=True)

if (img_index + 1) % 100 == 0:
    print('[%d/%d] loss_content:%.4f, loss_gan_d:%.4f' % (
        img_index + 1, args.stage1_iter + args.stage2_iter, loss.item(), loss_gan_d.item()))
    visualized_imgs = torch.cat([Ic, Is, Ics])
    output_name = output_dir / 'output{:d}.jpg'.format(img_index + 1)
    save_image(visualized_imgs, str(output_name), nrow=args.batch_size)
    mes = "iteration: " + str(img_index + 1) + " loss: " + str(loss.sum().item())
    logging.info(mes)
    model.save_ckpts()
    adjust_learning_rate(optimizer, img_index, args)
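
adjust_learning_rate and model.save_ckpts live elsewhere in the training script. For context, the schedule used by the public AdaIN/SANet training scripts looks like this; I am assuming this project uses something similar (args.lr and args.lr_decay follow that convention):

def adjust_learning_rate(optimizer, iteration_count, args):
    # Inverse-time decay, as in the AdaIN/SANet reference implementations
    lr = args.lr / (1.0 + args.lr_decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr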
# Geometric-consistency loss fragment
g_t = self.decoder(stylized)
g_t_feats = self.encode_with_intermediate(g_t)

# Rotate the content image, stylize it, rotate the result back (Ics_geo),
# then require the rotated-path output to match the direct output
rotated_content = self.rotate_tensor_rhw_90(content)
rotated_content_feats = self.encode_with_intermediate(rotated_content)
geo_stylized = self.transform(rotated_content_feats[3], style_feats[3])
geo_g_t = self.decoder(geo_stylized)
geo_g_t = self.rotate_tensor_ccw_90(geo_g_t)
geo_g_t_feats = self.encode_with_intermediate(geo_g_t)

loss_geo = 0.3 * (self.calc_content_loss(g_t_feats[3], geo_g_t_feats[3], norm=True)
                  + self.calc_content_loss(g_t_feats[4], geo_g_t_feats[4], norm=True))

del rotated_content, rotated_content_feats, geo_stylized, geo_g_t_feats
torch.cuda.empty_cache()
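
rotate_tensor_rhw_90 / rotate_tensor_ccw_90 are not shown anywhere in this post; a plausible implementation as model methods, using torch.rot90 so the two are exact inverses over the spatial dims (my assumption, not the original code):

def rotate_tensor_rhw_90(self, x):
    # Rotate a (B, C, H, W) tensor 90 degrees clockwise in the spatial plane
    return torch.rot90(x, k=-1, dims=(2, 3))

def rotate_tensor_ccw_90(self, x):
    # Exact inverse: 90 degrees counter-clockwise
    return torch.rot90(x, k=1, dims=(2, 3))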

Gating

class SANet(nn.Module):
    def __init__(self, in_planes):
        super(SANet, self).__init__()
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))

        # Gate network: produces a per-pixel, per-channel gating signal from the content features
        self.content_gate = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // 2, (1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_planes // 2, in_planes, (1, 1)),  # output channels must match the style feature channels
            nn.Sigmoid()  # keeps gate values in (0, 1)
        )

        self.sm = nn.Softmax(dim=-1)
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    def forward(self, content, style):
        F = self.f(mean_variance_norm(content))  # content query
        G = self.g(mean_variance_norm(style))    # style key
        H = self.h(style)                        # style value

        # Gating signal derived from the content features
        content_gate_signal = self.content_gate(content)

        b, c, h, w = F.size()
        F = F.view(b, -1, w * h).permute(0, 2, 1)
        G = G.view(b, -1, w * h)

        S = torch.bmm(F, G)
        S = self.sm(S)

        # Gate the style values H before aggregation: element-wise multiplication
        # lets content "importance" modulate how much style is passed through
        gated_H = H * content_gate_signal

        b, c, h, w = gated_H.size()
        gated_H = gated_H.view(b, -1, w * h)

        # Aggregate the gated style values with the attention map
        O = torch.bmm(gated_H, S.permute(0, 2, 1))
        O = O.view(b, c, h, w)
        O = self.out_conv(O)

        # Residual connection preserves part of the original content information
        O += content

        return O
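
A quick shape check for the gated SANet (a sketch; mean_variance_norm is defined here to mirror the method used in the AesSA variants above, since the class calls it as a module-level helper):

import torch
import torch.nn as nn

def mean_variance_norm(feat, eps=1e-5):
    # Assumed module-level helper: per-sample, per-channel spatial normalization
    b, c = feat.size()[:2]
    flat = feat.view(b, c, -1)
    mean = flat.mean(dim=2).view(b, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    return (feat - mean) / std

sanet = SANet(in_planes=512)
content = torch.randn(2, 512, 32, 32)
style = torch.randn(2, 512, 32, 32)   # the element-wise gate requires matching spatial size
assert sanet(content, style).shape == content.shape  # residual keeps the content shape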

PSNR 

import os
import cv2
import numpy as np
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
from skimage.metrics import peak_signal_noise_ratio as psnr

# python Score.py
# content_folder = "../inputs/content"  # folder containing the content images
content_folder = "../inputs/style"  # change to the folder containing the content images
output_folder = "./output"  # change to the folder containing the stylized images


psnr_scores = []
stylized_files = os.listdir(output_folder)

for stylized in stylized_files:
    stylized_img = cv2.imread(os.path.join(output_folder, stylized))  # stylized image
    name = stylized.split("_stylized_")  # parse the content image's name
    content_img = cv2.imread(os.path.join(content_folder, name[0] + '.jpg'))  # content image

    print(output_folder + '/' + stylized + "   " + content_folder + '/' + name[0] + '.jpg')

    # Compute the PSNR metric (reference image first, per skimage's convention)
    scorepsnr = psnr(content_img, stylized_img)
    psnr_scores.append(scorepsnr)

# Average score over all pairs
avg_scorepsnr = np.mean(psnr_scores)

print("Average PSNR score for {} image pairs: {:.4f}".format(len(stylized_files), avg_scorepsnr))

SSIM loss

import torch
from skimage.metrics import structural_similarity as ssim

def ssim_loss(tensor_a, tensor_b, data_range=None, win_size=11, K1=0.01, K2=0.03):
    """
    Negative-SSIM loss between two images.

    :param tensor_a: first image, torch.Tensor in CHW layout (or a NumPy HWC array)
    :param tensor_b: second image, same convention
    :param data_range: value range of the data; inferred from the dtype if None
    :param win_size: SSIM window size, default 11
    :param K1, K2: SSIM stability constants, defaults (0.01, 0.03)
    :return: negative SSIM value

    Note: the computation goes through NumPy, so gradients do not flow back
    into the network; see the differentiable kornia variant below.
    """
    # Convert CHW tensors to HWC NumPy arrays for skimage
    img_a_np = tensor_a.detach().cpu().numpy().transpose((1, 2, 0)) if isinstance(tensor_a, torch.Tensor) else tensor_a
    img_b_np = tensor_b.detach().cpu().numpy().transpose((1, 2, 0)) if isinstance(tensor_b, torch.Tensor) else tensor_b

    # channel_axis=-1 tells skimage the images are color (HWC)
    ssim_val = ssim(img_a_np, img_b_np, data_range=data_range, win_size=win_size,
                    K1=K1, K2=K2, channel_axis=-1)

    # Negate so that higher similarity means lower loss
    return -ssim_val
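
Because the version above is not differentiable, a training loss should use a tensor-native SSIM instead; a sketch with kornia (assuming (B, C, H, W) inputs in [0, 1]):

import kornia

def ssim_loss_differentiable(a, b, window_size=11):
    # kornia.losses.ssim_loss returns a loss based on 1 - SSIM, reduced over
    # the batch, and keeps the computation on the autograd graph
    return kornia.losses.ssim_loss(a, b, window_size)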

Edge operators (structure loss)

import torch
import kornia as K
import torchvision.transforms as T

# In __init__: Smooth L1 over edge maps, plus an un-normalizer that undoes ImageNet normalization
self.smooth_l1_loss = torch.nn.SmoothL1Loss()
self.unnormalizeTensor = T.Compose([T.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
                                                std=[1 / 0.229, 1 / 0.224, 1 / 0.225])])

# In the loss computation: match Sobel, Laplacian and Canny edge maps between
# the content image and the stylized output g_t (kornia filters expect BCHW input)
structure_loss = torch.tensor(0.0).cuda()
sobel_img = K.filters.sobel(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))))
structure_loss += self.smooth_l1_loss(sobel_img, K.filters.sobel(
    K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t))))
laplacian_img = K.filters.laplacian(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))),
                                    kernel_size=5)
structure_loss += self.smooth_l1_loss(laplacian_img, K.filters.laplacian(
    K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t)), kernel_size=5))
canny_img = K.filters.canny(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))))[0]  # [0] = edge magnitudes
structure_loss += self.smooth_l1_loss(canny_img, K.filters.canny(
    K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t)))[0])
structure_loss *= 5.0  # structure-loss weight
del sobel_img, laplacian_img, canny_img
