虚拟试穿
简介:本文梳理虚拟试穿算法框架结构,展示模特虚拟试穿上衣的效果,细说设计流程的详细步骤,提供相应的数据资源。
- 算法仓库:https://github.com/beauthy/DeepFashion_Try_On
github上不了,就访问:码云:虚拟试穿上衣测试:https://gitee.com/rpr/try-on_parse.git - 链接:https://pan.baidu.com/s/1nKUevnIMcGjaitVwIb7SRg
提取码:59wk - 测试用模型,鼓励大家根据网络训练自己的模型。
具体效果看下图or视频,想测试可以参见模型资源下载,之后Load和测试。
上述模型资源包括算法所设及的全部网络模型:latest_net_U.pth,latest_net_G1.pth,latest_net_G2.pth,latest_net_G.pth
测试效果:如图

测试效果:如视频
计算机视觉神经网络虚拟试穿测试
前言
本文将梳理算法实现过程原理。
提示:本文内容仅供学术研究与参考。
一、Try_On算法里面有什么?
0.环境; 1. 数据读取; 2. 数据模型:U-Net,G-Net; 3.损失函数; 4.调试常见的bug。二、梳理步骤
1.环境
代码如下(示例):




以上只等下次优化成requirements.txt再传上来。
2.读入数据
模型输入数据需要哪些呢?

测试数据集长什么样?

数据直观内容分析,我把模型需要的输入放一起,展示如下:

实际上,pose_关键点数据,和label_分割数据,是img_模特数据得到的(怎么生成关键点数据和人物分割数据的详细解读和代码,我再开一篇博客放上来);edge_数据就是待穿衣服color生成的。mask掩码数据是根据需要随机生成的。所以,完整的项目,的输入只需要模特和服装款式即可,也就是说可以实现给个人物和一件衣服就给实现换装。
再看,

看具体情况:通过photoshop的拾色器可以直观看到数据的值如下(label是灰度图):

背景的亮度L:0;面部的亮度L:9,左胳膊的亮度L:10,右胳膊的亮度L:8,上衣衣服位置的L:2。
把肢体图像分割出精确部分,用不同的亮度表示,到时候换衣服就有边界了。学姿势,纹理和褶皱等也有边界。
注意此处的L并不是该位置的像素值,只是亮度值。像素值可以用代码打印出来看。一下给出不同块儿的像素值。
# 具体划分区域Segmentation Label
0 -> Background
1 -> Hair
4 -> Upclothes
5 -> Left-shoe
6 -> Right-shoe
7 -> Noise
8 -> Pants
9 -> Left_leg
10 -> Right_leg
11 -> Left_arm
12 -> Face
13 -> Right_arm
注:名字带mask的三张掩码图,黑色区域亮度0,白色区域亮度为100,它们没有实际意义,可用于增加噪声,让模型稳定性好一些(我是这样理解的,因为训练的时候中间结果也有损失函数的backforword).
ok,以上就是输入数据,理清了没?
下面分析怎么使用的,以及后面模型是怎么组合的。
3. 输入配置
怎么读取数据集,生成模型输入需要的数据?项目写了一个配置options,专用于对数据集目录信息,训练测试信息和超参数进行配置的文件。

test.py文件中:
opt = TrainOptions().parse()
ctrl+鼠标左键点击TestOptions,找到opt对象的具体内容:
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
......
ctrl+鼠标左键点击BaseOptions:
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
.....
TestOptions类的initialize函数系重写,但还是调用了BaseOptions.initialize(self)的,所以BaseOptions.initialize的数据也包含了的。根据训练和测试所需要的数据不同,控制生成数据集和其他超参数。
4. 数据集处理详细
生成数据迭代器,同其他pytorch自制数据集相差不大。重点文件时aligned_dataset.py

具体一起来看看:
test.py文件中:
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
CreateDataLoader是封装数据迭代器的类:
点进去看一眼
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
具体内容在CustomDatasetDataLoader:
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
发现还封装了一层,数据是:self.dataset = CreateDataset(opt):
发现还有函数封装,
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
注意dataset.initialize(opt),封装数据过程中动不动都在初始化。
继续AlignedDataset,ctrl+鼠标左键点击AlignedDataset
class AlignedDataset(BaseDataset):
def initialize(self, opt):
......
def __getitem__(self, index):
......
找到,def __getitem__(self, index):,看到它是不是很眼熟了,就是pytorch生成batch数据可迭代数据。这个类继承于BaseDataset,父类有transform等方法:
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
elif opt.resize_or_crop == 'scale_width_and_crop':
new_w = opt.loadSize
new_h = opt.loadSize * h // w
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
#flip = random.random() > 0.5
flip = 0
return {
'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
osize = [256,192]
transform_list.append(transforms.Scale(osize, method))
if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w

本文深入解析了虚拟试穿算法的框架,涵盖数据读取、模型结构(如U-Net、G-Net)、数据预处理、关键步骤和网络训练过程,包括如何使用pose关键点和衣物分割数据。通过实例和代码演示,指导读者如何搭建并测试模型。
最低0.47元/天 解锁文章
4797

被折叠的 条评论
为什么被折叠?



