# Original:
class AesSA(nn.Module):
    """Attention module that re-colours normalized content with style statistics.

    Every content position attends over all style positions. The
    attention-weighted mean and standard deviation of the projected style
    features are then used to scale and shift the normalized content features.
    """

    def __init__(self, in_planes):
        super().__init__()
        # 1x1 projections: f -> query (content), g -> key (style), h -> value (style).
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        """Shift and scale ``feat`` by the statistics returned from ``calc_mean_std``.

        NOTE(review): ``calc_mean_std`` is defined outside this snippet;
        presumably it returns per-channel mean/std broadcastable to ``feat``.
        """
        mean, std = calc_mean_std(feat)
        shape = feat.size()
        return (feat - mean.expand(shape)) / std.expand(shape)

    def forward(self, content, style):
        """Return the content features re-scaled by attended style statistics.

        Args:
            content: 4-D content feature map ``(b, c, h, w)``.
            style: 4-D style feature map with the same batch and channel sizes.

        Returns:
            A tensor shaped like ``content``.
        """
        normalized_content = self.mean_variance_norm(content)
        query = self.f(normalized_content)
        key = self.g(style)
        value = self.h(style)

        # Flatten the style side: key -> (b, c, n_s), value -> (b, n_s, c).
        batch, _, style_h, style_w = key.size()
        n_style = style_w * style_h
        key = key.view(batch, -1, n_style).contiguous()
        value_flat = value.view(batch, -1, n_style).transpose(1, 2).contiguous()

        # Flatten the content side: query -> (b, n_c, c).
        batch, _, out_h, out_w = query.size()
        query = query.view(batch, -1, out_w * out_h).permute(0, 2, 1)

        # Each row is a probability distribution over the style positions.
        attn = self.sm(torch.bmm(query, key))

        # Attention-weighted first moment, and std from E[x^2] - E[x]^2.
        # relu clamps tiny negative values caused by rounding before the sqrt.
        mean = torch.bmm(attn, value_flat)
        second_moment = torch.bmm(attn, value_flat ** 2)
        std = torch.sqrt(torch.relu(second_moment - mean ** 2))

        # (b, n_c, c) -> (b, c, h, w)
        mean = mean.view(batch, out_h, out_w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(batch, out_h, out_w, -1).permute(0, 3, 1, 2).contiguous()
        return std * normalized_content + mean
# Modified: what gets reconstructed is the content of the style image.
class AesSA(nn.Module):
    """Window-local attention-based statistics transfer.

    Content and style feature maps are cut into non-overlapping square
    windows. Each content window attends only to the style window at the same
    grid position, and the attended mean/std of the style values re-colour the
    normalized content.

    Both spatial sizes must be multiples of ``window_size``. ``content`` and
    ``style`` must yield the same number of windows, otherwise the batched
    matrix product fails.
    """

    def __init__(self, in_planes, window_size=4):
        super().__init__()
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)
        self.window_size = window_size

    def mean_variance_norm(self, feat):
        """Shift and scale ``feat`` by the statistics returned from ``calc_mean_std``.

        NOTE(review): ``calc_mean_std`` is defined outside this snippet;
        presumably it returns per-channel mean/std broadcastable to ``feat``.
        """
        size = feat.size()
        mean, std = calc_mean_std(feat)
        return (feat - mean.expand(size)) / std.expand(size)

    @staticmethod
    def _window_partition(x, window_size):
        """Split ``(B, C, H, W)`` into ``(B * nH * nW, C, ws, ws)`` windows."""
        B, C, H, W = x.shape
        x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
        return x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)

    @staticmethod
    def _window_reverse(windows, window_size, hw):
        """Inverse of ``_window_partition`` for an image of size ``hw = (H, W)``."""
        n_h, n_w = hw[0] // window_size, hw[1] // window_size
        # Integer arithmetic: avoids the float round trip of the original.
        B = windows.shape[0] // (n_h * n_w)
        x = windows.reshape(B, n_h, n_w, -1, window_size, window_size)
        return x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, -1, hw[0], hw[1])

    def forward(self, content, style):
        """Return content features re-scaled by window-local style statistics.

        Args:
            content: content feature map ``(B, C, H, W)``.
            style: style feature map; must partition into the same number of
                windows as ``content``.

        Returns:
            A tensor of shape ``(B, C, H, W)``.
        """
        ws = self.window_size
        _, _, H_c, W_c = content.shape

        # Normalize over the WHOLE image, then cut into windows. Normalizing
        # every small window separately (as the earlier revision did for the
        # final re-colouring) erases image-level content structure.
        content_key = self.mean_variance_norm(content)
        content_windows = self._window_partition(content_key, ws)
        style_windows = self._window_partition(style, ws)

        F = self.f(content_windows)
        G = self.g(style_windows)
        H = self.h(style_windows)

        b, _, h_g, w_g = G.size()
        G = G.view(b, -1, w_g * h_g).contiguous()
        # (b, c, n) -> (b, n, c)
        style_flat = H.view(b, -1, w_g * h_g).transpose(1, 2).contiguous()
        b, _, h, w = F.size()
        F = F.view(b, -1, w * h).permute(0, 2, 1)

        # Attention of every content position over the style positions of the
        # matching window; each row sums to one.
        S = self.sm(torch.bmm(F, G))

        # Attention-weighted mean and std (from E[x^2] - E[x]^2) per position.
        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))
        mean = mean.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

        out_windows = std * content_windows + mean
        return self._window_reverse(out_windows, ws, (H_c, W_c))
# Modification 2: produces grid-shaped (blocky) results.
class AesSA(nn.Module):
    """Window-local attention-based statistics transfer.

    The projected content/style feature maps are cut into non-overlapping
    ``window_size`` x ``window_size`` windows. Attention, and the attended
    mean/std of the style values, are computed independently inside each
    window. The per-window statistics are then stitched back into full maps
    that scale and shift the normalized content.

    Both spatial sizes must be multiples of ``window_size``, and ``style``
    must have the same shape as ``content``.
    """

    def __init__(self, in_planes, window_size=8, shift_size=4):
        super(AesSA, self).__init__()
        self.window_size = window_size
        # Stored for API compatibility; nothing in this class reads it yet.
        self.shift_size = shift_size
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        """Shift and scale ``feat`` by the statistics returned from ``calc_mean_std``.

        NOTE(review): ``calc_mean_std`` is defined outside this snippet;
        presumably it returns per-channel mean/std broadcastable to ``feat``.
        """
        size = feat.size()
        mean, std = calc_mean_std(feat)
        return (feat - mean.expand(size)) / std.expand(size)

    def window_partition(self, feat):
        """Split ``(b, c, h, w)`` into ``(b * nH * nW, c, ws, ws)`` windows.

        Rows and columns inside a window keep their original orientation.
        """
        b, c, h, w = feat.shape
        ws = self.window_size
        feat_windows = feat.view(b, c, h // ws, ws, w // ws, ws)
        # -> (b, nH, nW, c, ws_row, ws_col)
        feat_windows = feat_windows.permute(0, 2, 4, 1, 3, 5).contiguous()
        return feat_windows.view(-1, c, ws, ws)

    def window_reverse(self, windows, b, h, w):
        """Inverse of ``window_partition``.

        Args:
            windows: tensor of shape ``(b * nH * nW, c, ws, ws)``.
            b, h, w: batch size and spatial size of the map to rebuild.

        Returns:
            A tensor of shape ``(b, c, h, w)``.
        """
        ws = self.window_size
        windows = windows.reshape(b, h // ws, w // ws, -1, ws, ws)
        # -> (b, c, nH, ws_row, nW, ws_col), then merge the grid and window axes.
        return windows.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, -1, h, w)

    def forward(self, content, style):
        """Return content features re-scaled by window-local style statistics.

        Args:
            content: content feature map ``(b, c, h, w)``.
            style: style feature map of the same shape.

        Returns:
            A tensor of shape ``(b, c, h, w)``.
        """
        content_key = self.mean_variance_norm(content)
        query = self.f(content_key)
        key = self.g(style)
        value = self.h(style)
        b, c, h, w = query.size()
        ws = self.window_size

        # Per window: query (n, ws*ws, c), key (n, c, ws*ws), value (n, ws*ws, c).
        q_win = self.window_partition(query).flatten(2).transpose(1, 2)
        k_win = self.window_partition(key).flatten(2)
        v_win = self.window_partition(value).flatten(2).transpose(1, 2).contiguous()

        # Rows of the attention matrix are distributions over the style
        # positions of the same window.
        attn = self.sm(torch.bmm(q_win, k_win))

        # Attention-weighted mean and std (from E[x^2] - E[x]^2).
        mean = torch.bmm(attn, v_win)
        std = torch.sqrt(torch.relu(torch.bmm(attn, v_win ** 2) - mean ** 2))

        # (n, ws*ws, c) -> (n, c, ws, ws) -> (b, c, h, w)
        mean = self.window_reverse(mean.transpose(1, 2).reshape(-1, c, ws, ws), b, h, w)
        std = self.window_reverse(std.transpose(1, 2).reshape(-1, c, ws, ws), b, h, w)
        return std * content_key + mean
# Modification 3:
class AesSA(nn.Module):
    """Attention-based statistics transfer restricted to local windows.

    Query/key/value maps are split into non-overlapping square windows. Each
    content window attends to the style window at the same grid position, and
    the attended mean/std of the style values re-colour the normalized
    content.

    Both spatial sizes must be multiples of ``window_size``, and ``style``
    must have the same shape as ``content``.
    """

    def __init__(self, in_planes, window_size=8, shift_size=4):
        super(AesSA, self).__init__()
        self.window_size = window_size
        # Kept so existing constructor calls still work; currently unused.
        self.shift_size = shift_size
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        """Shift and scale ``feat`` by the statistics returned from ``calc_mean_std``.

        NOTE(review): ``calc_mean_std`` is defined outside this snippet;
        presumably it returns per-channel mean/std broadcastable to ``feat``.
        """
        size = feat.size()
        mean, std = calc_mean_std(feat)
        return (feat - mean.expand(size)) / std.expand(size)

    def window_partition(self, feat):
        """Split ``(b, c, h, w)`` into ``(b * nH * nW, c, ws, ws)`` windows."""
        b, c, h, w = feat.shape
        feat_windows = feat.view(b, c, h // self.window_size, self.window_size,
                                 w // self.window_size, self.window_size)
        # Keep (row, col) order inside each window: permuting with (..., 5, 3)
        # would transpose every window.
        feat_windows = feat_windows.permute(0, 2, 4, 1, 3, 5).contiguous()
        feat_windows = feat_windows.view(-1, c, self.window_size, self.window_size)
        return feat_windows

    def window_reverse(self, windows, b, h, w):
        """Rebuild a ``(b, c, h, w)`` map from ``(b * nH * nW, c, ws, ws)`` windows."""
        windows = windows.reshape(b, h // self.window_size, w // self.window_size, -1,
                                  self.window_size, self.window_size)
        # (b, nH, nW, c, ws, ws) -> (b, c, nH, ws, nW, ws) -> (b, c, h, w)
        windows = windows.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, -1, h, w)
        return windows

    def forward(self, content, style):
        """Return content features re-scaled by window-local style statistics.

        Args:
            content: content feature map ``(b, c, h, w)``.
            style: style feature map of the same shape.

        Returns:
            A tensor of shape ``(b, c, h, w)``.
        """
        content_key = self.mean_variance_norm(content)
        F = self.f(content_key)
        G = self.g(style)
        H = self.h(style)
        b, c, h, w = F.size()

        F_windows = self.window_partition(F)
        G_windows = self.window_partition(G)
        H_windows = self.window_partition(H)

        b_w, _, h_w, w_w = F_windows.size()
        # query: (b_w, ws*ws, c); key: (b_w, c, ws*ws); value: (b_w, ws*ws, c)
        F_windows = F_windows.view(b_w, -1, h_w * w_w).permute(0, 2, 1)
        G_windows = G_windows.view(b_w, -1, h_w * w_w).contiguous()
        style_flat = H_windows.view(b_w, -1, h_w * w_w).transpose(1, 2).contiguous()

        # Each row of S is a distribution over the style positions of the window.
        S = self.sm(torch.bmm(F_windows, G_windows))

        # Attention-weighted mean and std (from E[x^2] - E[x]^2).
        mean = torch.bmm(S, style_flat)
        std = torch.sqrt(torch.relu(torch.bmm(S, style_flat ** 2) - mean ** 2))

        # The statistics are (b_w, ws*ws, c): move channels in front of the
        # window positions before stitching the windows back together.
        mean = mean.permute(0, 2, 1).reshape(b_w, c, h_w, w_w)
        std = std.permute(0, 2, 1).reshape(b_w, c, h_w, w_w)
        mean = self.window_reverse(mean, b, h, w)
        std = self.window_reverse(std, b, h, w)
        return std * content_key + mean
# Modified: the output looks like the style image; suspected to be a problem with AdaAttN.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AesSA(nn.Module):
    """Window-local attention-based statistics transfer.

    Content and style projections are cut into non-overlapping square windows.
    Attention and the attended mean/std of the style values are computed per
    window, and the stitched statistics scale and shift the normalized
    content.

    Both spatial sizes must be multiples of ``window_size``, and ``style``
    must have the same shape as ``content``.
    """

    def __init__(self, in_planes, window_size=7):
        super(AesSA, self).__init__()
        self.window_size = window_size
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)

    def mean_variance_norm(self, feat):
        """Shift and scale ``feat`` by the statistics returned from ``calc_mean_std``.

        NOTE(review): ``calc_mean_std`` is defined outside this snippet;
        presumably it returns per-channel mean/std broadcastable to ``feat``.
        """
        size = feat.size()
        mean, std = calc_mean_std(feat)
        return (feat - mean.expand(size)) / std.expand(size)

    def window_partition(self, x, window_size):
        """Split ``(B, C, H, W)`` into ``(B * nH * nW, C, ws, ws)`` windows."""
        B, C, H, W = x.shape
        x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
        windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        return windows.view(-1, C, window_size, window_size)

    def window_reverse(self, windows, window_size, H, W):
        """Rebuild a ``(B, C, H, W)`` map from ``(B * nH * nW, C, ws, ws)`` windows.

        ``H`` and ``W`` are the integer height and width of the full map.
        """
        n_h, n_w = H // window_size, W // window_size
        # Integer arithmetic instead of the float division of the original.
        B = windows.shape[0] // (n_h * n_w)
        x = windows.view(B, n_h, n_w, -1, window_size, window_size)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        return x.view(B, -1, H, W)

    def forward(self, content, style):
        """Return content features re-scaled by window-local style statistics.

        Args:
            content: content feature map ``(B, C, H, W)``.
            style: style feature map of the same shape.

        Returns:
            A tensor of shape ``(B, C, H, W)``.
        """
        _, channels, height, width = content.shape
        ws = self.window_size
        n_pos = ws * ws

        content_key = self.mean_variance_norm(content)

        # Projections use names that do not clash with the unpacked height and
        # width: reusing ``H`` for the value map would hand a tensor to
        # ``window_reverse`` as the image height.
        query = self.f(content_key)
        key = self.g(style)
        value = self.h(style)

        query_windows = self.window_partition(query, ws)
        key_windows = self.window_partition(key, ws)
        value_windows = self.window_partition(value, ws)

        # query: (n, ws*ws, C); key: (n, C, ws*ws); value: (n, ws*ws, C)
        query_flat = query_windows.view(-1, channels, n_pos).permute(0, 2, 1)
        key_flat = key_windows.view(-1, channels, n_pos)
        value_flat = value_windows.view(-1, channels, n_pos).permute(0, 2, 1)

        # Each row is a distribution over the style positions of the window.
        attn = self.sm(torch.bmm(query_flat, key_flat))

        # Attention-weighted mean and std (from E[x^2] - E[x]^2).
        mean = torch.bmm(attn, value_flat)
        std = torch.sqrt(torch.relu(torch.bmm(attn, value_flat ** 2) - mean ** 2))

        # (n, ws*ws, C) -> (n, C, ws, ws)
        mean = mean.view(-1, ws, ws, channels).permute(0, 3, 1, 2).contiguous()
        std = std.view(-1, ws, ws, channels).permute(0, 3, 1, 2).contiguous()

        # Stitch the windows back into full (B, C, H, W) maps.
        mean = self.window_reverse(mean, ws, height, width)
        std = self.window_reverse(std, ws, height, width)
        return std * content_key + mean
# Modified:
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionUnit(nn.Module):
    """Window-local attention that injects style features into content features.

    Content and style maps are normalized, projected to half the channel
    count and cut into non-overlapping square windows. Inside each window the
    style values are aggregated with content-to-style attention. The result is
    stitched back, projected to the full channel count and fused with the
    original content features.

    Both spatial sizes must be multiples of ``window_size``, and ``Fs`` must
    have the same shape as ``Fc``.
    """

    def __init__(self, channels, window_size=7):
        super(AttentionUnit, self).__init__()
        self.window_size = window_size
        self.relu6 = nn.ReLU6()
        self.f = nn.Conv2d(channels, channels // 2, (1, 1))
        self.g = nn.Conv2d(channels, channels // 2, (1, 1))
        self.h = nn.Conv2d(channels, channels // 2, (1, 1))
        self.out_conv = nn.Conv2d(channels // 2, channels, (1, 1))
        self.softmax = nn.Softmax(dim=-1)
        # NOTE(review): FuseUnit is defined outside this snippet.
        self.fuse = FuseUnit(channels)
        self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)
        # One pixel of reflection padding keeps the 3x3 conv size-preserving.
        self.pad = nn.ReflectionPad2d((1, 1, 1, 1))

    def mean_variance_norm(self, feat):
        """Normalize each channel of each sample to zero mean and unit variance."""
        mean = feat.mean(dim=(2, 3), keepdim=True)
        var = feat.var(dim=(2, 3), unbiased=False, keepdim=True)
        # The small constant avoids division by zero on constant channels.
        return (feat - mean) / torch.sqrt(var + 1e-5)

    def window_partition(self, x, window_size):
        """Split ``(B, C, H, W)`` into ``(B * nH * nW, C, ws, ws)`` windows."""
        B, C, H, W = x.shape
        x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
        windows = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        return windows.view(-1, C, window_size, window_size)

    def window_reverse(self, windows, window_size, H, W):
        """Rebuild a ``(B, C, H, W)`` map from ``(B * nH * nW, C, ws, ws)`` windows."""
        n_h, n_w = H // window_size, W // window_size
        # Integer arithmetic instead of the float division of the original.
        B = windows.shape[0] // (n_h * n_w)
        x = windows.reshape(B, n_h, n_w, -1, window_size, window_size)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
        return x.view(B, -1, H, W)

    def forward(self, Fc, Fs):
        """Return the content features fused with attended style features.

        Args:
            Fc: content feature map ``(B, C, H, W)``.
            Fs: style feature map of the same shape.

        Returns:
            Whatever ``self.fuse`` returns for ``(Fc, stylized_features)``.
        """
        B, C, H, W = Fc.shape
        ws = self.window_size
        half = C // 2
        n_pos = ws * ws

        # Normalize, project and bound the activations.
        f_Fc = self.relu6(self.f(self.mean_variance_norm(Fc)))
        g_Fs = self.relu6(self.g(self.mean_variance_norm(Fs)))
        h_Fs = self.relu6(self.h(self.mean_variance_norm(Fs)))

        # query: (n, ws*ws, C/2); key and value: (n, C/2, ws*ws)
        query = self.window_partition(f_Fc, ws).view(-1, half, n_pos).permute(0, 2, 1)
        key = self.window_partition(g_Fs, ws).view(-1, half, n_pos)
        value = self.window_partition(h_Fs, ws).view(-1, half, n_pos)

        # attention[n, i, j]: weight of style position j for content position i.
        attention = self.softmax(torch.bmm(query, key))

        # (n, C/2, ws*ws_style) @ (n, ws*ws_style, ws*ws_content)
        # The value tensor must stay channel-first here: permuting it makes
        # the inner dimensions of the product disagree.
        Fcs = torch.bmm(value, attention.permute(0, 2, 1))

        # Stitch the windows back; a plain view to (B, C/2, H, W) would
        # scramble the spatial layout.
        Fcs = self.window_reverse(Fcs.view(-1, half, ws, ws), ws, H, W)

        Fcs = self.relu6(self.out_conv(Fcs))
        Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
        return self.fuse(Fc, Fcs)
# Save intermediate results.
# NOTE(review): indentation was lost in this snippet, so which of the statements
# after the `if` header belong to its body cannot be determined from here, and the
# enclosing training loop that defines img_index/args/loss/etc. is not shown —
# confirm both against the training script.
output_dir = Path('samples')
# Create the directory (and missing parents); do nothing if it already exists.
output_dir.mkdir(exist_ok=True, parents=True)
# True on every 100th iteration, counting img_index from 1.
if (img_index + 1) % 100 == 0:
# Progress line: current iteration / total iterations (stage 1 + stage 2) and two losses.
# NOTE(review): the value printed under the label "loss_content" is `loss` — confirm
# that this is the intended quantity.
print('[%d/%d] loss_content:%.4f, loss_gan_d:%.4f' % (
img_index + 1, args.stage1_iter + args.stage2_iter, loss.item(), loss_gan_d.item()))
# Concatenate Ic, Is and Ics along dim 0 so they end up in one image grid
# (presumably content, style and stylized batches — verify).
visualized_imgs = torch.cat([Ic, Is, Ics])
output_name = output_dir / 'output{:d}.jpg'.format(img_index + 1)
# nrow=args.batch_size: number of images per grid row.
save_image(visualized_imgs, str(output_name), nrow=args.batch_size)
# Log the iteration number and the summed loss value.
mes = "iteration: " + str(img_index + 1) + " loss: " + str(loss.sum().item())
logging.info(mes)
# Persist checkpoints (implementation of save_ckpts is not shown here).
model.save_ckpts()
# Update the optimizer's learning rate for the current iteration
# (schedule is defined in adjust_learning_rate, outside this snippet).
adjust_learning_rate(optimizer, img_index, args)
# Geometric-consistency loss: stylize a rotated copy of the content, rotate the
# result back, and compare its encoder features with those of the regular output.
# NOTE(review): this is a fragment from inside a method; `stylized`, `content`
# and `style_feats` are defined outside this snippet.
# Decode the stylized features into an image and re-encode it.
g_t = self.decoder(stylized)
g_t_feats = self.encode_with_intermediate(g_t)
# Rotate the content input and encode the rotated copy.
# NOTE(review): presumably a 90-degree rotation that rotate_tensor_ccw_90 below
# undoes — verify against the helper implementations.
rotated_content = self.rotate_tensor_rhw_90(content)
rotated_content_feats = self.encode_with_intermediate(rotated_content)
# Ics_geo: stylize the rotated content with the same style features (index 3).
geo_stylized = self.transform(rotated_content_feats[3], style_feats[3])
geo_g_t = self.decoder(geo_stylized)
# Rotate the decoded image back before re-encoding it.
geo_g_t = self.rotate_tensor_ccw_90(geo_g_t)
geo_g_t_feats = self.encode_with_intermediate(geo_g_t)
# 0.3 x (content loss on feature index 3 + content loss on feature index 4),
# both computed with norm=True.
loss_geo = 0.3 * (self.calc_content_loss(g_t_feats[3], geo_g_t_feats[3], norm=True) + self.calc_content_loss(g_t_feats[4], geo_g_t_feats[4], norm=True))
# Drop the local references to the intermediates (geo_g_t itself is not deleted)
# and ask PyTorch to release its unused cached GPU memory.
del rotated_content, rotated_content_feats, geo_stylized, geo_g_t_feats
torch.cuda.empty_cache()
# Gating
class SANet(nn.Module):
    """Style-attention block whose style values are gated by the content features.

    A sigmoid gate computed from the content features multiplies the projected
    style values element-wise before they are aggregated with content-to-style
    attention. The aggregated result is projected and added to the content.
    Because the gate is applied element-wise, content and style must have
    shapes that broadcast against each other.
    """

    def __init__(self, in_planes):
        super().__init__()
        # 1x1 projections: f -> query (content), g -> key (style), h -> value (style).
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        # Bottleneck that turns content features into a per-element gate.
        # The last conv restores the content channel count and the sigmoid
        # keeps the gate inside (0, 1).
        bottleneck = in_planes // 2
        self.content_gate = nn.Sequential(
            nn.Conv2d(in_planes, bottleneck, (1, 1)),
            nn.ReLU(),
            nn.Conv2d(bottleneck, in_planes, (1, 1)),
            nn.Sigmoid(),
        )
        self.sm = nn.Softmax(dim=-1)
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    def forward(self, content, style):
        """Return ``content`` plus the projected, gated, attended style features.

        Args:
            content: content feature map ``(b, c, h, w)``.
            style: style feature map.

        Returns:
            A tensor shaped like ``content``.
        """
        # NOTE(review): mean_variance_norm is a module-level helper defined
        # outside this snippet.
        query = self.f(mean_variance_norm(content))
        key = self.g(mean_variance_norm(style))
        value = self.h(style)

        # Gate derived from the (un-normalized) content features.
        gate = self.content_gate(content)

        # The spatial size of the query is used to flatten both query and key.
        batch, _, height, width = query.size()
        positions = width * height
        query_flat = query.view(batch, -1, positions).permute(0, 2, 1)
        key_flat = key.view(batch, -1, positions)
        attention = self.sm(torch.bmm(query_flat, key_flat))

        # Let the content gate decide how much of every style value passes.
        gated_value = value * gate
        batch, channels, height, width = gated_value.size()
        gated_flat = gated_value.view(batch, -1, width * height)

        # Aggregate the gated style values with the attention weights.
        fused = torch.bmm(gated_flat, attention.permute(0, 2, 1))
        fused = self.out_conv(fused.view(batch, channels, height, width))

        # Residual connection keeps part of the original content information.
        fused += content
        return fused
# PSNR
import os
import cv2
import numpy as np
import matplotlib
matplotlib.use('TKAgg') #
import cv2
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch
# Usage: python Score.py
#
# Computes the average PSNR between every stylized image in `output_folder`
# and its reference image in `content_folder`. The reference file name is the
# part of the stylized file name before "_stylized_", with a ".jpg" extension.
# content_folder = "../inputs/content"  # folder that holds the reference images
content_folder = "../inputs/style"  # folder that holds the reference images
output_folder = "./output"  # folder that holds the stylized images

psnr_scores = []
# Sorted so the per-pair log lines come out in a reproducible order.
stylized_files = sorted(os.listdir(output_folder))
for stylized in stylized_files:
    stylized_path = output_folder + '/' + stylized
    # Parse the reference image's name out of the stylized file name.
    name = stylized.split("_stylized_")
    content_path = content_folder + '/' + name[0] + '.jpg'

    stylized_img = cv2.imread(stylized_path)  # stylized image
    content_img = cv2.imread(content_path)  # reference image
    print(stylized_path + " " + content_path)

    # cv2.imread returns None (it does not raise) for missing/unreadable files.
    if stylized_img is None or content_img is None:
        print("Skipping pair: could not read one of the images")
        continue
    # PSNR is only defined for images of identical shape; skip instead of
    # aborting the whole run on one bad pair.
    if stylized_img.shape != content_img.shape:
        print("Skipping pair: shape mismatch {} vs {}".format(stylized_img.shape, content_img.shape))
        continue

    # Compute the PSNR metric for this pair.
    scorepsnr = psnr(stylized_img, content_img)
    psnr_scores.append(scorepsnr)

# Average over the pairs that were actually scored.
if psnr_scores:
    avg_scorepsnr = np.mean(psnr_scores)
    print("Average PSNR score for {} image pairs: {:.4f}".format(len(psnr_scores), avg_scorepsnr))
else:
    avg_scorepsnr = float("nan")
    print("No image pairs could be scored.")
# SSIM loss
import torch
from skimage.metrics import structural_similarity as ssim
import torch.nn.functional as F
def ssim_loss(tensor_a, tensor_b, data_range=None, size_average=True, win_size=11, win_sigma=1.5, K=(0.01, 0.03)):
    """Return the negative SSIM between two images (or two batches of images).

    The value is computed with scikit-image on NumPy copies of the inputs.
    It is therefore detached from autograd: it can be used as a metric but
    gradients cannot flow through it.

    :param tensor_a: first image. A ``torch.Tensor`` is expected channel-first,
        ``(C, H, W)`` or ``(B, C, H, W)``; any other input is converted with
        ``numpy.asarray`` and expected channel-last, ``(H, W)``, ``(H, W, C)``
        or ``(B, H, W, C)``.
    :param tensor_b: second image, same layout and shape as ``tensor_a``.
    :param data_range: value range of the images. When ``None`` and the data
        is floating point, it is inferred as ``max - min`` over both inputs;
        for integer data scikit-image derives it from the dtype. Pass it
        explicitly for results that are comparable across image pairs.
    :param size_average: for batched input, return the mean loss if ``True``,
        otherwise one loss per sample. Has no effect on a single image.
    :param win_size: SSIM window size (must be odd).
    :param win_sigma: standard deviation of the Gaussian SSIM window.
    :param K: the two SSIM stabilisation constants ``(K1, K2)``.
    :return: negative SSIM as a NumPy scalar, or a 1-D NumPy array for batched
        input with ``size_average=False``.
    :raises ValueError: if the two inputs do not have the same shape.
    """
    import numpy as np

    def _to_channel_last(image):
        # Tensors: drop autograd/device state, then move channels to the end.
        if isinstance(image, torch.Tensor):
            array = image.detach().cpu().numpy()
            if array.ndim == 3:
                return array.transpose((1, 2, 0))
            if array.ndim == 4:
                return array.transpose((0, 2, 3, 1))
            return array
        return np.asarray(image)

    img_a_np = _to_channel_last(tensor_a)
    img_b_np = _to_channel_last(tensor_b)
    if img_a_np.shape != img_b_np.shape:
        raise ValueError(
            "Input images must have the same shape, got {} and {}".format(img_a_np.shape, img_b_np.shape))

    # Recent scikit-image versions refuse floating-point input without an
    # explicit data_range, so infer it from the data as documented above.
    is_float = np.issubdtype(img_a_np.dtype, np.floating) or np.issubdtype(img_b_np.dtype, np.floating)
    if data_range is None and is_float:
        low = min(img_a_np.min(), img_b_np.min())
        high = max(img_a_np.max(), img_b_np.max())
        # Fall back to 1.0 for constant images to avoid a zero range.
        data_range = float(high - low) or 1.0

    def _negative_ssim(img_a, img_b):
        # scikit-image names these parameters sigma/K1/K2. Gaussian weighting
        # has to be switched on for sigma to take effect.
        options = dict(data_range=data_range, win_size=win_size, gaussian_weights=True,
                       sigma=win_sigma, K1=K[0], K2=K[1])
        if img_a.ndim == 3:
            # Colour image: the last axis holds the channels.
            options["channel_axis"] = -1
        return -ssim(img_a, img_b, **options)

    if img_a_np.ndim == 4:
        # Batched input: score every sample separately.
        losses = np.array([_negative_ssim(a, b) for a, b in zip(img_a_np, img_b_np)])
        return losses.mean() if size_average else losses

    return _negative_ssim(img_a_np, img_b_np)
# Edge operators
import kornia as K
import torchvision.transforms as T
# Edge-based structure loss: compares Sobel, Laplacian and Canny responses of the
# content image with those of the generated image g_t.
# NOTE(review): the first two statements assign attributes on `self` while the rest
# computes a loss from `content` and `g_t`, so this snippet presumably combines
# constructor code and forward code of a class that is not shown — confirm.
# Smooth-L1 criterion used for every edge-map comparison below.
self.smooth_l1_loss = torch.nn.SmoothL1Loss()
# Normalize with mean = -m/s and std = 1/s computes x * s + m, i.e. it undoes a
# Normalize with m = (0.485, 0.456, 0.406) and s = (0.229, 0.224, 0.225).
self.unnormalizeTensor = T.Compose([T.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225])])
# Scalar accumulator placed on the CUDA device (this line requires CUDA).
structure_loss = torch.tensor(0.0).cuda()
# Sobel edge maps of the un-normalized grayscale images.
# NOTE(review): `content` is passed through squeeze(0) but `g_t` is not, so the two
# operands may differ in rank — confirm the shapes of both tensors at this point.
sobel_img = K.filters.sobel(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))))
structure_loss += self.smooth_l1_loss(sobel_img, K.filters.sobel(
K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t))))
# Laplacian responses with a 5x5 kernel.
laplacian_img = K.filters.laplacian(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))),
kernel_size=5)
structure_loss += self.smooth_l1_loss(laplacian_img, K.filters.laplacian(
K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t)), kernel_size=5))
# Canny: only element [0] of the returned tuple is compared.
canny_img = K.filters.canny(K.color.rgb_to_grayscale(self.unnormalizeTensor(content.squeeze(0))))[0]
structure_loss += self.smooth_l1_loss(canny_img, K.filters.canny(
K.color.rgb_to_grayscale(self.unnormalizeTensor(g_t)))[0])
# Weight the summed structure loss by 5.
structure_loss *= 5.0
# Drop the local references to the content-side edge maps.
del sobel_img, laplacian_img, canny_img