基于改进LPTN网络的多任务生成模型

项目需求:

开发一个人体肤色均匀功能

现有技术参考:

LPTN生成网络,实现单维度图像的生成

https://github.com/csjliang/LPTN

存在问题:

如何让模型区分不同维度,实现多维度的生成。如:一共有五种肤色选择,如何通过一个模型实现按需生成,希望生成白色或者棕色皮肤等。而非每一个维度训练一个模型,浪费资源的占用。

解决思路:

参考stargan方法,改造LPTN网络的生成器和判别器,将生成器的输入从1个(图像)变为2个(图像加标签)。标签分类和生成图像真伪判别公用一个判别器,除了判断生成图像的真伪外,返回图像的类别。

1.数据准备与加载

1)首先拷贝LPTN公开代码到自己的工程,配置好相关的环境(不再赘述)

2)准备数据集,在datasets/my_data(没有的自己创建)文件夹底下创建train和test两个文件夹,在train和test文件夹底下分别再创建A和B文件夹,A文件夹中存放原始图像数据,B中存放需要生成的目标图像数据(标签图像),注意A和B文件夹底下图像名称需要对应匹配。保证数据中人体皮肤占比大,否则生成效果不理想。

3)如何给每个图像输入对应的标签?拟在读取图像数据时根据每个图像的名称进行标签添加。为每张图像的名称添加对应的标签前缀,如:目前图像是白色的则添加“white_"前缀。前缀添加脚本代码如下:

import os


def add_prefix_to_images(folder_path, prefix):
    # 遍历文件夹
    for filename in os.listdir(folder_path):
        # 检查文件是否为图片(这里简单判断后缀为.jpg)
        if filename.endswith(".jpg") or filename.endswith(".png"):
            # 构建完整的文件路径
            old_path = os.path.join(folder_path, filename)

            # 构建新的文件名,添加指定前缀
            new_filename = f"{prefix}_{filename}"
            new_path = os.path.join(folder_path, new_filename)

            # 重命名文件
            os.rename(old_path, new_path)
            print(f"Renamed: {filename} to {new_filename}")

#图像存储路径
folder_path_A = r'E:\chen\skin_512\test\A'
prefix = "white"
# 调用函数
add_prefix_to_images(folder_path_A, prefix)

4)修改LPTN代码中codes/data/paired_image_dataset.py文件,为数据加载时获取图像的标签,定义extract_label_from_filename方法,根据“_”切割方法,切割出图像的标签,定义label_to_onehot方法,将标签转换为onehot编码,代码如下:

    def extract_label_from_filename(self, filename):
        # Extract label from filename (assuming label is before the first "_")
        label_end = filename.find('_')
        label = filename[:label_end]
        return label

    def label_to_onehot(self, label):
        # Convert label to one-hot encoding
        label_mapping = {'white': 0, 'brown': 1, 'pink': 2, 'blue': 3, 'yellow': 4}  # Add more labels if needed
        onehot = torch.zeros(len(label_mapping))
        if label in label_mapping:
            onehot[label_mapping[label]] = 1
        return onehot

在最后的返回代码中添加onehot标签返回,以便训练时调用:

 return {
            'lq': img_lq,
            'gt': img_gt,
            'lq_path': lq_path,
            'gt_path': gt_path,
            'lq_label': onehot_lq_label,
            'gt_label': onehot_gt_label,
        }

2.生成器和判别器修改

生成器:

修改codes\models\archs中LPTN_arch.py文件

存在问题:在什么地方添加标签合适?可以随意一个位置添加标签还是某个位置添加才能更好的引导模型进行按需生成是一个难点。通过对LPTN神经网络结构的研究,发现图像的颜色特征主要由低维特征决定的,纹理特征主要由高维特征。因此我们可以尝试在低维度中添加标签信息。那么,在低维特征中,一开始就添加好呢还是某些层添加效果更好?这就需要我们不断的进行验证了。经过不断的尝试,我们发现不管是一开始添加还是在中间某些层添加,效果都不够理想。再次回到LPTN网络结构图,我们发现能否在低维特征生成的最后一层添加一个3×3的卷积网络实现低维特征与onehot标签的深度融合,加强图像的特征信息。在3×3的卷积网络后添加LeakyReLU 激活函数,增加了网络的稳定性。最后通过反卷积网络实现图像的恢复。代码如下:

class Trans_low(nn.Module):
    def __init__(self, num_residual_blocks):
        super(Trans_low, self).__init__()

        model = [nn.Conv2d(3, 16, 3, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.Conv2d(16, 64, 3, padding=1),
            nn.LeakyReLU()]

        for _ in range(num_residual_blocks):
            model += [ResidualBlock(64)]

        model += [nn.Conv2d(64, 16, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(16, 3, 3, padding=1)]

        self.model = nn.Sequential(*model)
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        self.actv = nn.LeakyReLU(0.2)
        self.deconv2 = nn.Conv2d(16, 3, 3, 1, 1)

    def forward(self, x, label=None):
       
        labe_low = nn.functional.interpolate(label, size=(x.shape[2], x.shape[3]))
        pyr_A_cat = torch.cat([self.model(x), labe_low], 1)
        pyr_A_cat = self.conv2(pyr_A_cat)
        pyr_A_cat = self.actv(pyr_A_cat)
        pyr_A_cat = self.deconv2(pyr_A_cat)
        out = x + pyr_A_cat
        out = torch.tanh(out)
        return out

改进LPTN结构代码如下:

class LPTN(nn.Module):
    def __init__(self, nrb_low=5, nrb_high=3, num_high=3):
        super(LPTN, self).__init__()

        self.lap_pyramid = Lap_Pyramid_Conv(num_high)
        trans_low = Trans_low(nrb_low)
        trans_high = Trans_high(nrb_high, num_high=num_high)
        self.trans_low = trans_low.cuda()
        self.trans_high = trans_high.cuda()

    def forward(self, real_A_full, onehot_label):

        pyr_A = self.lap_pyramid.pyramid_decom(img=real_A_full)
        onehot_label = onehot_label.unsqueeze(2).unsqueeze(3)
        onehot_label = onehot_label.expand(-1, -1, pyr_A[-1].shape[2], pyr_A[-1].shape[3])
        fake_B_low = self.trans_low(pyr_A[-1], onehot_label)
        real_A_up = nn.functional.interpolate(pyr_A[-1], size=(pyr_A[-2].shape[2], pyr_A[-2].shape[3]))
        fake_B_up = nn.functional.interpolate(fake_B_low, size=(pyr_A[-2].shape[2], pyr_A[-2].shape[3]))
        high_with_low = torch.cat([pyr_A[-2], real_A_up, fake_B_up], 1)
        pyr_A_trans = self.trans_high(high_with_low, pyr_A, fake_B_low)
        fake_B_full = self.lap_pyramid.pyramid_recons(pyr_A_trans)

        return fake_B_full

判别器:

修改codes\models\archs中discriminator_arch.py

添加标签分类器

self.domain_classifier = nn.Conv2d(128, num_domains, 8, padding=0)

完整代码如下:

class Discriminator(nn.Module):
    def __init__(self,num_domains=5):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 128),
        )
        # Binary classification output
        self.binary_classifier = nn.Conv2d(128, 1, 8, padding=0)

        # Domain classification output
        self.domain_classifier = nn.Conv2d(128, num_domains, 8, padding=0)

    def forward(self, img_input):
        features = self.model(img_input)
        binary_output = self.binary_classifier(features)
        domain_output = self.domain_classifier(features)

        #print("domain_output.view(domain_output.size(0), domain_output.size(1))",domain_output.view(domain_output.size(0), domain_output.size(1)))

        return binary_output, domain_output.view(domain_output.size(0), domain_output.size(1))

3.损失函数修改

修改codes\models中的lptn_model.py模型初始化文件

在feed_data方法中添加onehot标签加载

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
        if 'ref' in data:
            self.ref = data['ref'].to(self.device)
        if 'lq_label' in data:
            self.lq_label = data['lq_label'].to(self.device)
        if 'gt_label' in data:
            self.gt_label = data['gt_label'].to(self.device)
        

在optimize_parameters参数初始化方法中修改生成器和判别器的损失函数

修改生成器的输入,判别器的输出,添加标签损失函数:

self.output = self.net_g(self.lq, self.lq_label)


fake_g_pred,fake_g_label_pred = self.net_d(self.output)

#添加标签损失函数
l_g_cls = F.binary_cross_entropy_with_logits(fake_g_label_pred, self.gt_label,
                                                         size_average=False)/fake_g_label_pred.size(0)

通过训练发现,模型生成图像与目标图像颜色存在色差与彩色噪声问题

继续修改生成器损失函数,添加颜色损失函数引导生成器关注颜色特征,添加感知损失函数保证生成图像的纹理特征,代码如下:

 # 计算颜色分布差异
  diff = torch.abs(self.output.float()*255 - self.gt.float()*255)

 # 计算均方误差(MSE)损失
  l_g_h = torch.mean(diff ** 2)

    def init_training_settings(self):
        self.net_g.train()
        self.net_d.train()
        train_opt = self.opt['train']

        # define losses
        if train_opt.get('pixel_opt'):
            pixel_type = train_opt['pixel_opt'].pop('type')
            cri_pix_cls = getattr(loss_module, pixel_type)
            self.cri_pix = cri_pix_cls(**train_opt['pixel_opt']).to(
                self.device)
            self.cri_pix_weight = train_opt['pixel_opt'].pop('loss_weight')
        else:
            self.cri_pix = None

        if train_opt.get('l1_opt'):
            l1_type = train_opt['l1_opt'].pop('type')
            cri_l1_cls = getattr(loss_module, l1_type)
            self.cri_l1 = cri_l1_cls(**train_opt['l1_opt']).to(self.device)
            self.cri_l1_weight = train_opt['l1_opt'].pop('loss_weight')
        if train_opt.get('p_opt'):
            p_type = train_opt['p_opt'].pop('type')
            cri_p_cls = getattr(loss_module, p_type)
            self.cri_p = cri_p_cls(**train_opt['p_opt']).to(self.device)

其次,各损失函数的权重值需要自己根据自己的数据集不断地进行调优以获取最佳的效果。

训练数据集的路径、各损失函数权重等在options\train\LPTN\train_FiveK.yml文件修改,最终psnr为32左右。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值