import torch
import torch.nn as nn
import torch.nn.functional as F
def build_model(net, warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor):
    out = net(warp1_tensor, warp2_tensor, mask1_tensor, mask2_tensor)

    # Update the learned masks: outside the overlap each image keeps its own
    # region; inside the overlap (mask1*mask2) the network output `out`
    # softly splits ownership between the two images.
    learned_mask1 = (mask1_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*out
    learned_mask2 = (mask2_tensor - mask1_tensor*mask2_tensor) + mask1_tensor*mask2_tensor*(1-out)

    # Blend the inputs with the learned masks to form the stitched image.
    # Inputs are assumed to be in [-1, 1], hence the +1/-1 shifts.
    stitched_image = (warp1_tensor+1.) * learned_mask1 + (warp2_tensor+1.)*learned_mask2 - 1.

    out_dict = {}
    out_dict.update(learned_mask1=learned_mask1, learned_mask2=learned_mask2, stitched_image=stitched_image)
    return out_dict
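# Note on the mask update above: since `out` lies in (0, 1), the learned masks
# only re-partition the overlap and always cover the union exactly:
#   learned_mask1 + learned_mask2 = mask1 + mask2 - mask1*mask2
# so every pixel of the panorama stays covered and the seam can only move
# inside the overlap region. A runnable check appears after Network below.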
class DownBlock(nn.Module):
    def __init__(self, inchannels, outchannels, dilation, pool=True):
        super(DownBlock, self).__init__()
        blk = []
        if pool:
            blk.append(nn.MaxPool2d(kernel_size=2, stride=2))  # optional 2x downsampling
        blk.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation=dilation))  # conv layer
        blk.append(nn.ReLU(inplace=True))  # activation
        blk.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation=dilation))  # conv layer
        blk.append(nn.ReLU(inplace=True))
        self.layer = nn.Sequential(*blk)  # wrap the block in a sequential container

        # Initialize convolution weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)  # Kaiming normal initialization
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)  # BatchNorm init (no BN layers above; kept for safety)
                m.bias.data.zero_()

    def forward(self, x):
        return self.layer(x)
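# Note on spatial sizes (a checking aid, not part of the original code): for
# kernel_size=3, stride=1 the output height is H + 2*padding - 2*dilation, so
# padding=1 with dilation>1 shrinks the map by 2*(dilation-1) per conv.
# UpBlock.forward compensates by interpolating back to the skip connection's
# size, but small inputs can shrink to an invalid size in the deep blocks;
# e.g. a 512x512 input is safe (x1=512, x2=252, x3=118, x4=47, x5=7),
# while a 256x256 input fails inside down5.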
class UpBlock(nn.Module):
    def __init__(self, inchannels, outchannels, dilation):
        super(UpBlock, self).__init__()
        #self.convt = nn.ConvTranspose2d(inchannels, outchannels, kernel_size=2, stride=2)
        # Halve the channel count of the upsampled features so that, after
        # concatenation with the skip connection, the total is `inchannels` again.
        self.halfChanelConv = nn.Sequential(
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv = nn.Sequential(
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1, dilation=dilation),
            nn.ReLU(inplace=True)
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x1, x2):
        # Upsample x1 to the skip connection's spatial size, then fuse.
        x1 = F.interpolate(x1, size=(x2.size()[2], x2.size()[3]), mode='nearest')
        x1 = self.halfChanelConv(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
# predict the composition mask of img1
class Network(nn.Module):
def __init__(self, nclasses=1):
super(Network, self).__init__()
self.down1 = DownBlock(3, 32, 1, pool=False)
self.down2 = DownBlock(32, 64, 2)
self.down3 = DownBlock(64, 128,3)
self.down4 = DownBlock(128, 256, 4)
self.down5 = DownBlock(256, 512, 5)
self.up1 = UpBlock(512, 256, 4)
self.up2 = UpBlock(256, 128, 3)
self.up3 = UpBlock(128, 64, 2)
self.up4 = UpBlock(64, 32, 1)
self.out = nn.Sequential(
nn.Conv2d(32, nclasses, kernel_size=1),
nn.Sigmoid()
)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
    def forward(self, x, y, m1, m2):
        # m1 and m2 are accepted for interface compatibility but unused here;
        # the mask composition happens in build_model.
        # Siamese encoding: both warped images share the same encoder weights.
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        y1 = self.down1(y)
        y2 = self.down2(y1)
        y3 = self.down3(y2)
        y4 = self.down4(y3)
        y5 = self.down5(y4)
        # Decode from feature differences so the predicted mask focuses on
        # regions where the two aligned images disagree.
        res = self.up1(x5-y5, x4-y4)
        res = self.up2(res, x3-y3)
        res = self.up3(res, x2-y2)
        res = self.up4(res, x1-y1)
        res = self.out(res)
        return res
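# --- Smoke test (a minimal sketch; the shapes and sizes are assumptions) ---
# Verifies that the output resolution matches the input and that the learned
# masks re-partition the overlap as derived after build_model above.
if __name__ == '__main__':
    net = Network()
    warp1 = torch.rand(1, 3, 512, 512) * 2 - 1   # inputs assumed in [-1, 1]
    warp2 = torch.rand(1, 3, 512, 512) * 2 - 1
    mask1 = torch.ones(1, 3, 512, 512)
    mask2 = torch.ones(1, 3, 512, 512)
    result = build_model(net, warp1, warp2, mask1, mask2)
    print(result['stitched_image'].shape)        # torch.Size([1, 3, 512, 512])
    union = mask1 + mask2 - mask1 * mask2
    assert torch.allclose(result['learned_mask1'] + result['learned_mask2'],
                          union, atol=1e-5)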
# ===================== loss functions =====================
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torchvision.models as models
def l_num_loss(img1, img2, l_num=1):
    """Mean L_p error between two tensors.

    Args:
        img1: first input tensor.
        img2: second input tensor.
        l_num: order p of the norm; defaults to 1 (L1 loss).

    Returns:
        Mean of |img1 - img2|**l_num over all elements.
    """
    return torch.mean(torch.abs((img1 - img2)**l_num))
def boundary_extraction(mask):
    """Extract a widened inner boundary band of a binary mask.

    Args:
        mask: binary mask tensor (1 inside the valid region, 0 outside).

    Returns:
        A tensor that is 1 on a band of pixels just inside the mask boundary
        (widened by repeated dilation) and 0 elsewhere.
    """
    ones = torch.ones_like(mask)
    zeros = torch.zeros_like(mask)

    # 3x3 all-ones kernel, used both to detect and to dilate the boundary
    in_channel = 1
    out_channel = 1
    kernel = [[1, 1, 1],
              [1, 1, 1],
              [1, 1, 1]]
    kernel = torch.FloatTensor(kernel).expand(out_channel, in_channel, 3, 3)
    if torch.cuda.is_available():
        kernel = kernel.cuda()
        ones = ones.cuda()
        zeros = zeros.cuda()
    weight = nn.Parameter(data=kernel, requires_grad=False)

    # Convolving the complement marks every pixel with at least one
    # out-of-mask neighbor, i.e. the one-pixel boundary; six further
    # conv + threshold passes dilate that band to widen it.
    x = F.conv2d(1-mask, weight, stride=1, padding=1)
    x = torch.where(x < 1, zeros, ones)
    for _ in range(6):
        x = F.conv2d(x, weight, stride=1, padding=1)
        x = torch.where(x < 1, zeros, ones)

    # Keep only the part of the band that lies inside the mask.
    return x*mask
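# --- Illustrative check (a sketch; the toy mask is an assumption) ---
# On a centered square mask, the returned band hugs the inside of the square
# and is roughly 7 pixels wide (1 detection pass + 6 dilation passes).
if __name__ == '__main__':
    toy_mask = torch.zeros(1, 1, 64, 64)
    toy_mask[:, :, 16:48, 16:48] = 1.0
    if torch.cuda.is_available():
        toy_mask = toy_mask.cuda()   # the helper moves its kernel to CUDA
    band = boundary_extraction(toy_mask)
    print(band.sum().item(), toy_mask.sum().item())  # band covers only part of the mask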
# Boundary term, Eq. (10) in the paper
def cal_boundary_term(input1_tensor, input2_tensor, mask1_tensor, mask2_tensor, stitched_image):
    """Compute the boundary loss.

    Args:
        input1_tensor: first input image tensor.
        input2_tensor: second input image tensor.
        mask1_tensor: mask of the first input image.
        mask2_tensor: mask of the second input image.
        stitched_image: the stitched result.

    Returns:
        total_loss: total boundary loss.
        boundary_mask1: boundary mask of the first image.
    """
    # boundary_extraction returns a 0/1 band; multiplying by the other mask
    # keeps only the part of each boundary that lies in the overlap region.
    boundary_mask1 = mask1_tensor * boundary_extraction(mask2_tensor)
    boundary_mask2 = mask2_tensor * boundary_extraction(mask1_tensor)

    # L1 losses on both boundary bands: the stitched image should match each
    # input along the other image's boundary, per Eq. (10).
    loss1 = l_num_loss(input1_tensor*boundary_mask1, stitched_image*boundary_mask1, 1)
    loss2 = l_num_loss(input2_tensor*boundary_mask2, stitched_image*boundary_mask2, 1)

    return loss1+loss2, boundary_mask1
# Smoothness term on the stitched image: penalizes mask transitions that fall
# on high-gradient image content, so the seam is pushed into smooth regions
# and the result looks more natural. Corresponds to Eq. (12) in the paper.
def cal_smooth_term_stitch(stitched_image, learned_mask1):
    """Compute the seam smoothness loss of the stitched image.

    Args:
        stitched_image: the stitched result.
        learned_mask1: the learned composition mask.

    Returns:
        loss: the smoothness loss value.
    """
    delta = 1
    # Mask gradients in the vertical and horizontal directions
    dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])
    dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
    # Image gradients of the stitched result in the same two directions
    dh_diff_img = torch.abs(stitched_image[:,:,0:-1*delta,:] - stitched_image[:,:,delta:,:])
    dw_diff_img = torch.abs(stitched_image[:,:,:,0:-1*delta] - stitched_image[:,:,:,delta:])
    # The product is large only where the seam crosses strong image edges
    dh_pixel = dh_mask * dh_diff_img
    dw_pixel = dw_mask * dw_diff_img

    loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
    return loss
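# Note (equivalence check, assuming PyTorch >= 1.8): the slicing above is the
# standard finite-difference pattern and is identical to torch.diff, e.g.
#   dh_mask == torch.abs(torch.diff(learned_mask1, dim=2))
#   dw_mask == torch.abs(torch.diff(learned_mask1, dim=3))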
# Smoothness term on the content difference in the overlap region
def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
    """Compute the smoothness loss on the image content difference.

    Args:
        img1: first input image tensor.
        img2: second input image tensor.
        learned_mask1: the learned composition mask.
        overlap: mask of the overlap region.

    Returns:
        loss: the difference smoothness loss value.
    """
    # Squared difference of the two images, restricted to the overlap
    diff_feature = torch.abs(img1-img2)**2 * overlap

    delta = 1
    # Mask gradients in both directions
    dh_mask = torch.abs(learned_mask1[:,:,0:-1*delta,:] - learned_mask1[:,:,delta:,:])
    dw_mask = torch.abs(learned_mask1[:,:,:,0:-1*delta] - learned_mask1[:,:,:,delta:])
    # Sum (not difference) of neighboring difference values: the seam is
    # penalized wherever it passes through poorly aligned pixels
    dh_diff_img = torch.abs(diff_feature[:,:,0:-1*delta,:] + diff_feature[:,:,delta:,:])
    dw_diff_img = torch.abs(diff_feature[:,:,:,0:-1*delta] + diff_feature[:,:,:,delta:])

    dh_pixel = dh_mask * dh_diff_img
    dw_pixel = dw_mask * dw_diff_img

    loss = torch.mean(dh_pixel) + torch.mean(dw_pixel)
    return loss
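# --- Unsupervised training step (a minimal sketch; the loss weights, the
# optimizer wiring, and the helper name train_step are assumptions, not
# values from the original code) ---
def train_step(net, optimizer, warp1, warp2, mask1, mask2,
               w_boundary=10000., w_smooth=1000., w_diff=1000.):
    optimizer.zero_grad()
    batch_out = build_model(net, warp1, warp2, mask1, mask2)
    stitched = batch_out['stitched_image']
    lm1 = batch_out['learned_mask1']
    overlap = mask1 * mask2
    # Combine the three terms defined above into one objective
    boundary_loss, _ = cal_boundary_term(warp1, warp2, mask1, mask2, stitched)
    smooth_stitch = cal_smooth_term_stitch(stitched, lm1)
    smooth_diff = cal_smooth_term_diff(warp1, warp2, lm1, overlap)
    total = w_boundary*boundary_loss + w_smooth*smooth_stitch + w_diff*smooth_diff
    total.backward()
    optimizer.step()
    return total.item()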
class FeatureConsistencyLoss(nn.Module):
    """VGG feature-consistency loss (strengthens the boundary constraint)."""
    def __init__(self, layer_indices=[3, 8, 15], alpha=0.5, use_cuda=True):
        super().__init__()
        # Load a pretrained VGG with frozen weights
        self.vgg = self._load_vgg()
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.layers = layer_indices
        self.alpha = alpha  # weight of the feature loss

        # VGG normalization statistics (0-255 scale)
        self.register_buffer('mean', torch.tensor([123.68, 116.779, 103.939]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1))

        # Register the feature-extraction layers
        self.feature_extractor = self._build_feature_extractor()

        # Device placement
        if use_cuda and torch.cuda.is_available():
            self.mean = self.mean.cuda()
            self.std = self.std.cuda()
            self.feature_extractor = self.feature_extractor.cuda()
            self.vgg = self.vgg.cuda()
    def _load_vgg(self):
        """Load the VGG19 model."""
        try:
            # Try the environment variable, then the default cache location
            torch_home = os.environ.get('TORCH_HOME', os.path.expanduser('~/.cache/torch'))
            model_path = os.path.join(torch_home, 'checkpoints', 'vgg19-dcbb9e9d.pth')

            # If a local copy exists, load it directly
            if os.path.exists(model_path):
                print(f"Loading VGG19 from local cache: {model_path}")
                vgg = models.vgg19(pretrained=False)
                state_dict = torch.load(model_path)
                vgg.load_state_dict(state_dict)
                return vgg.features
            else:
                # Otherwise download the weights
                print("Downloading VGG19 model...")
                vgg = models.vgg19(pretrained=True)
                # Make sure the cache directory exists
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                # Save the downloaded weights
                torch.save(vgg.state_dict(), model_path)
                print(f"Model saved to: {model_path}")
                return vgg.features

        except (FileNotFoundError, PermissionError) as e:
            print(f"Failed to load VGG model: {e}")
            # Fallback 1: a lighter backbone (note: its layer indices differ
            # from VGG19's, so layer_indices should be adapted accordingly)
            print("Falling back to MobileNetV2")
            return models.mobilenet_v2(pretrained=True).features
        except Exception as e:
            print(f"Unexpected error: {e}")
            # Fallback 2: a small hand-built feature extractor
            print("Using a small substitute model")
            return nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1), nn.ReLU()
            )
    def _build_feature_extractor(self):
        """Build the feature-extractor module list."""
        # Store every layer in a ModuleList so intermediate outputs can be collected
        feature_extractor = nn.ModuleList()
        for name, module in self.vgg.named_children():
            feature_extractor.append(module)
        return feature_extractor

    def _get_features(self, x, extractor):
        """Run the extractor and collect every intermediate feature map."""
        features = []
        for module in extractor:
            x = module(x)
            features.append(x)
        return features
    def forward(self, pred_img, target_img, boundary_mask):
        """Compute the feature-consistency loss."""
        # Preprocess inputs ([-1, 1] -> VGG input space)
        pred_img = (pred_img + 1) * 127.5    # [-1, 1] -> [0, 255]
        target_img = (target_img + 1) * 127.5

        # Normalize with the VGG statistics
        pred_img = (pred_img - self.mean) / self.std
        target_img = (target_img - self.mean) / self.std

        # Make sure the boundary mask is single-channel [batch, 1, H, W]
        if boundary_mask.size(1) == 3:
            # Average a three-channel mask down to one channel
            boundary_mask = torch.mean(boundary_mask, dim=1, keepdim=True)

        # Extract features. The prediction branch must stay differentiable so
        # the loss can backpropagate into the network (the VGG weights are
        # already frozen); only the target branch runs under no_grad.
        features_pred = self._get_features(pred_img, self.feature_extractor)
        with torch.no_grad():
            features_target = self._get_features(target_img, self.feature_extractor)

        # Accumulate the loss over the selected layers
        loss = 0
        current_alpha = self.alpha
        for layer_idx in self.layers:
            # Skip indices outside the extractor's range
            if layer_idx >= len(features_pred) or layer_idx >= len(features_target):
                continue
            feat_pred = features_pred[layer_idx]
            feat_target = features_target[layer_idx]

            # Resize the boundary mask to the current feature map's spatial size
            downsampled_mask = F.interpolate(
                boundary_mask,
                size=feat_pred.shape[2:],
                mode='bilinear',
                align_corners=False
            )
            # Key fix: use expand (not expand_as) to broadcast the mask over
            # the feature channels
            downsampled_mask = downsampled_mask.expand(
                -1,
                feat_pred.size(1),  # number of feature channels
                -1,
                -1
            )

            # Masked squared feature difference, normalized by the mask area
            diff = torch.sum(
                (feat_pred - feat_target) ** 2 * downsampled_mask
            ) / (downsampled_mask.sum() + 1e-6)

            # Weighted sum; the weight halves with each deeper layer
            loss += diff * current_alpha
            current_alpha *= 0.5

        return loss
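# --- Usage sketch (a minimal example; the tensor sizes and the toy boundary
# band are assumptions, not the original training code) ---
if __name__ == '__main__':
    feat_loss_fn = FeatureConsistencyLoss(use_cuda=False)
    pred = torch.rand(1, 3, 256, 256) * 2 - 1    # stitched image in [-1, 1]
    target = torch.rand(1, 3, 256, 256) * 2 - 1  # reference image in [-1, 1]
    b_mask = torch.zeros(1, 3, 256, 256)
    b_mask[:, :, 120:136, :] = 1.0               # toy horizontal boundary band
    print(feat_loss_fn(pred, target, b_mask).item())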
Please refer to my network.py code and loss-function code above. I am trying to stitch two images with unsupervised training. Because my two images overlap heavily and are aligned top-to-bottom, their borders coincide along long stretches, giving infinitely many boundary intersection points; as a result, the seam's start and end points become the positions of the second image's two upper corners in the stitched image, and objects with salient features near the boundary are stitched poorly. How can I solve this problem?