import torch
import torch.nn as nn
import torch.nn.functional as F
# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):
# vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1,3,1,1))
# if torch.cuda.is_available():
# vgg_mean = vgg_mean.cuda()
# vgg_input = input_255-vgg_mean
# #x = vgg_model.features[0](vgg_input)
# #FeatureMap_list.append(x)
# for i in range(0,layer_index+1):
# if i == 0:
# x = vgg_model.features[0](vgg_input)
# else:
# x = vgg_model.features[i](x)
# return x
def l_num_loss(img1, img2, l_num=1):
"""计算L1损失。
参数:
img1: 第一个输入图像tensor。
img2: 第二个输入图像tensor。
l_num: Lp范数的阶数,默认为1,即L1损失。
返回:
L1损失的均值。
"""
return torch.mean(torch.abs((img1 - img2)**l_num))
def boundary_extraction(mask):
"""提取图像边界。
参数:
mask: 输入的二值掩码,0表示非边界,1表示边界。
返回:
返回一个显示边界的图像,非边界的位置为0,边界的位置为输入mask的值。
"""
ones = torch.ones_like(mask)
zeros = torch.zeros_like(mask)
# 定义卷积核,3x3的全1核
in_channel = 1
out_channel = 1
kernel = [[1, 1, 1],
[1, 1, 1],
[1, 1, 1]]
kernel = torch.FloatTensor(kernel).expand(out_channel,in_channel,3,3) # 用3×3卷积核检测边界
if torch.cuda.is_available():
kernel = kernel.cuda()
ones = ones.cuda()
zeros = zeros.cuda()
weight = nn.Parameter(data=kernel, requires_grad=False)
#dilation
# 多次卷积加强边界,将小于1的像素设置为0,大于等于1的像素设置为1,强调边界位置
# 使用卷积核进行膨胀操作,多次卷积操作加强边界
x = F.conv2d(1-mask,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
x = F.conv2d(x,weight,stride=1,padding=1)
x = torch.where(x < 1, zeros, ones)
# 返回结果像素值为1的是边界,像素值为0的非边界。
return x*mask
# 边界损失由公式(10)
def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_tesnor, stitched_image):
"""计算边界损失。
参数:
inpu1_tesnor: 第一个输入图像tensor。
inpu2_tesnor: 第二个输入图像tensor。
mask1_tesnor: 第一个输入图像的掩码。
mask2_tesnor: 第二个输入图像的掩码。
stitched_image: 拼接后的图像。
返回:
total_loss: 总的边界损失。
boundary_mask1: 第一个图像的边界掩码。
"""
# 提取完的mask中的像素是0或1,乘完得到边界的mask,不为0的像素是边界
boundary_mask1 = mask1_tesnor * boundary_extraction(mask2_tesnor)
boundary_mask2 = mask2_tesnor * boundary_extraction(mask1_tesnor)
# 两个边界用L1范数计算损失
loss1 = l_num_loss(inpu1_tesnor*boundary_mask1, stitched_image*boundary_mask1, 1)
loss2 = l_num_loss(inpu2_tesnor*boundary_mask2, stitched_image*boundary_mask2, 1)
# 返回总的边界项损失和边界mask1,对应论文公式(10)
return loss1+loss2, boundary_mask1
# 计算平滑项损失,惩罚拼接图中不平滑区域,使结果更自然
# 对应论文中公式(12)
def cal_smooth_term_stitch(stitched_image, learned_mask1):
"""计算拼接图的平滑损失。
参数:
stitched_image: 拼接后的图像。
learned_mask1: 学习得到的掩码。
返回:
loss: 平滑损失的值。
"""
delta = 1
# 计算掩码在垂直和水平方向的变化量
dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:]) # 计算learned_mask1两个方向的变化量
dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
# 计算拼接图像在两个方向的差异
dh_diff_img = torch.abs(stitched_image[:,:,0:-1*delta,:] - stitched_image[:,:,delta:,:]) # 计算stitched_image两个方向的差值
dw_diff_img = torch.abs(stitched_image[:,:,:,0:-1*delta] - stitched_image[:,:,:,delta:])
dh_pixel = dh_mask * dh_diff_img #相乘得到差异
dw_pixel = dw_mask * dw_diff_img
loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
return loss
# 平滑项差异损失
def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
"""计算图像内容的平滑差异损失。
参数:
img1: 第一个输入图像tensor。
img2: 第二个输入图像tensor。
learned_mask1: 学习到的掩码。
overlap: 重叠区域的掩码。
返回:
loss: 平滑项差异损失的值。
"""
diff_feature = torch.abs(img1-img2)**2 * overlap # 两个图像在重叠区域差异
delta = 1
# 在掩码中计算变化量
dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])
dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
# 计算图像差异在两个方向上的变化
dh_diff_img = torch.abs(diff_feature[:,:,0:-1*delta,:] + diff_feature[:,:,delta:,:])
dw_diff_img = torch.abs(diff_feature[:,:,:,0:-1*delta] + diff_feature[:,:,:,delta:])
# 计算最终损失
dh_pixel = dh_mask * dh_diff_img
dw_pixel = dw_mask * dw_diff_img
loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
return loss
# dh_zeros = torch.zeros_like(dh_pixel)
# dw_zeros = torch.zeros_like(dw_pixel)
# if torch.cuda.is_available():
# dh_zeros = dh_zeros.cuda()
# dw_zeros = dw_zeros.cuda()
# loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)
# return loss, dh_pixel以上是我的损失函数
最新发布