本文作者为 360 奇舞团前端开发工程师
随着AI的火热发展,涌现了一些AI模特换装的前端工具(比如weshop网站),他们是怎么实现的呢?使用了什么技术呢?下文我们就来探索一下其实现原理。
总体的实现流程如下:我们将下图中的这个模特的图片,使用Segment Anything Model在后端分割图层,然后将分割后的图层mask信息返回给前端处理。在前端中选择需要保留的图层信息(如下图中的模特的衣服图层),然后将选中的图层信息交给后端中的Stable Diffusion处理。后端使用原始图片结合选中的图层蒙版图片结合图生图的功能,可以实现weshop等网站的模特换衣等功能。

本文先简单介绍一下使用SAM智能图层分割,然后主要介绍一下在前端中怎么对分割后的图层进行选择的处理流程。
使用SAM识别图层
首先我们需要对图层进行分割,在SAM出来之前,我们需要使用PS将模特的衣服选取出来,然后倒出衣服的模板,然后再使用其他工具进行替换。但是现在有了SAM后,我们可以对图片中的事物进去只能区分,获取各种物品的图层。
Segment Anything Model(SAM)是一种尖端的图像分割模型,可以进行快速分割,为图像分析任务提供无与伦比的多功能性。SAM 的先进设计使其能够在无需先验知识的情况下适应新的图像分布和任务,这一功能称为零样本传输。SAM 使任何人都可以在不依赖标记数据的情况下为其数据创建分段掩码。
要深入了解 Segment Anything 模型和 SA-1B 数据集,请访问Segment Anything 网站(https://segment-anything.com/)并查看研究论文Segment Anything(https://arxiv.org/abs/2304.02643)。
我们使用SAM进行图像分割,将一个图片中的物体分割成不同的部分。
def mask2rle(img):
'''
img: numpy array, 1 - mask, 0 - background
Returns run length as string formated
'''
pixels = img.T.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return ' '.join(str(x) for x in runs)
def trans_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=False)
list = []
index = 0
# 对每个注释进行处理
for ann in sorted_anns:
bool_array = ann['segmentation']
# 将boolean类型的数组转换为int类型
int_array = bool_array.astype(int)
# 转化为RLE格式
rle = mask2rle(int_array)
list.append({"index": index, "mask": rle})
index += 1
return list
image = cv2.imread('<your image path>')
import sys
sys.path.append('<your segment-anything link path>')
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# sam 模型路径
sam_checkpoint = '<your sam model path>'
# 根据下载的模型,设置对应的类型
model_type = "vit_h"
# device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
# 处理sam返回的图层信息
mask_list = trans_anns(masks)
mask_obj = {
"height": image.shape[0],
"width": image.shape[1],
"mask_list": mask_list
}
import json
print(json.dumps(mask_obj))
运行以上python代码之前,需要配置sam的python环境,具体的配置描述请查看sam的官方描述。
我们通过以上代码,将我们提供的图片,通过SAM处理后,返回图层分割数据。在trans_anns方法中,将图层按照area从小到大的顺序排序。遍历各个图层,将boolean类型的数组转换为 0 1 int类型,然后对二维numpy array类型的0 1二进制mask图像转换为RLE格式。
RLE是一种简单的无损数据压缩算法,通常用于表示连续的相同值的序列。RLE编码的字符串通常用于在图像分割等任务中存储和传输二进制掩码信息,以便更有效地表示图像中的目标区域。并且方便数据压缩和传输。我们参照的这种编解码方式。也可以使用coco RLE的编解码方式。
将编码后的各图层信息存储到list中,就可以通过接口传输给前端处理了。
前端选择图层
下面这些是本文的重点,在前端将刚才解析后的mask_list信息展示,并可以通过交互选取需要保留的模版,并生成最终合并选取的mask生成一个需要保留的服装模版。
body中的基本组件为
<div id="layer-box" style=" width: 500px; height: 500px;position: relative">
<img style="width: 100%; height: 100%; position: absolute" src="https://p0.ssl.qhimg.com/t01989f0d446bed3e58.jpg" />
</div>
<div id="save" @click="save" style="margin-top: 20px;margin-right: 20px; margin-left: 20px;">保存</div>
<canvas id="mergedCanvas" style="border:1px solid #000;"></canvas>
id为layer-box的div组件作为各个mask的父组件,用于查找和管理各个mask的隐藏和展示。其子组件中的第一个标签是展示原始的模特图片的。
id为save的组件在点击时可以处理保存选中的各个mask为一个新的mask图片,用于处理图片合成。
id为mergedCanvas的canvas是进行图片合成和展示合成后的图片的。
解析SAM处理后的mask_list信息
/**
* rle格式图片信息转换为mask信息
*/
function rle2mask(mask_rle, shape = [500, 500]) {
/*
mask_rle: run-length as string forma