利用CLIP做text-image对数据集

CLIP


有相应的文本,图像;制作文本图像对


一、图像处理?

代码如下(示例):

def crop_to_square(img):
    size = 512
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
        ]
    )
    return image_transforms(img)

二、CLIP类

CLIP相关方法

代码如下(示例):

class CLIP(object):
    def __init__(self):
        self.device = "cuda"
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        tokenizer = clip.tokenize
        model = model.cuda()
        self.model = model
        self.preprocess = preprocess
        self.tokenizer = tokenizer

    def text_emb(self, text_ls):
        if isinstance(text_ls, str):
            text_ls = [text_ls]
        text = self.tokenizer(text_ls, truncate=True).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text)
        return text_features

    def img_emb(self, img):
        image = self.preprocess(img).unsqueeze(0).to("cuda")
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return image_features

    def __call__(self, image, text, softmax=False):
        if isinstance(text, str):
            text = [text]

        if isinstance(image, list):
            image = [self.preprocess(i).unsqueeze(0).to("cuda") for i in image]
            image = torch.concat(image)
        else:
            image = self.preprocess(image).unsqueeze(0).to("cuda")

        text = self.tokenizer(text).to(self.device)

        if softmax:
            with torch.no_grad():
                logits_per_image, logits_per_text = self.model(image, text)
                probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            return probs
        else:
            with torch.no_grad():
                image_features = self.model.encode_image(image)
                text_features = self.model.encode_text(text)

                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)
                similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
                s = similarity[0][0]

            return s
import os
import sys
import torch
from PIL import Image
import glob
import pickle
import argparse
from torchvision import transforms
import numpy as np
import random
from sklearn.metrics.pairwise import cosine_similarity
import clip

def main():
    clip_model = CLIP()
    data_dir = args.directory

    source_concept = args.concept
    os.makedirs(args.outdir, exist_ok=True)
    all_data = glob.glob(os.path.join(data_dir, "*.p"))
    res_ls = []
    for idx, cur_data_f in enumerate(all_data):
        cur_data = pickle.load(open(cur_data_f, "rb"))
        cur_img = Image.fromarray(cur_data["img"])
        cur_text = cur_data["text"]

        cur_img = crop_to_square(cur_img)
        score = clip_model(cur_img, "a photo of a {}".format(source_concept))
        ####选取置性度高于0.24
        if score > 0.24:
            res_ls.append((cur_img, cur_text))

    if len(res_ls) < args.num:
        Exception("Not enough data from the source concept to select from. Please add more in the folder. ")

    all_prompts = [d[1] for d in res_ls]
    text_emb = clip_model.text_emb(all_prompts)
    text_emb_target = clip_model.text_emb("a photo of a {}".format(source_concept))
    text_emb_np = text_emb.cpu().float().numpy()
    text_emb_target_np = text_emb_target.cpu().float().numpy()
    res = cosine_similarity(text_emb_np, text_emb_target_np).reshape(-1)
    candidate = np.argsort(res)[::-1][:300]
    random_selected_candidate = random.sample(list(candidate), args.num)
    final_list = [res_ls[i] for i in random_selected_candidate]
    for i, data in enumerate(final_list):
        img, text = data
        cur_data = {
            "img": np.array(img),
            "text": text,
        }
        pickle.dump(cur_data, open(os.path.join(args.outdir, "{}.p".format(i)), "wb"))


def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--directory', type=str,
                        help="", default='')
    parser.add_argument('-od', '--outdir', type=str,
                        help="", default='')
    parser.add_argument('-n', '--num', type=int,
                        help="", default=100)
    parser.add_argument('-c', '--concept', type=str, required=True,
                        help="")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_arguments(sys.argv[1:])
    main()
### 技术原理 Text-to-Image 生成技术的核心在于将自然语言描述转换为对应的视觉内容。这一过程通常涉及以下几个关键步骤: 1. **文本编码**:首先,输入的文本通过一个文本编码器(如BERT或CLIP)进行处理,将其转化为高维向量表示。这些向量捕捉了文本中的语义信息,为后续的图像生成提供基础。 2. **条件生成模型**:接下来,编码后的文本作为条件输入到生成模型中。常见的生成模型包括GANs(生成对抗网络)、VAEs(变分自编码器)和扩散模型(Diffusion Models)。其中,扩散模型因其在高质量图像生成方面的出色表现而受到广泛关注。扩散模型通过逐步添加噪声来破坏训练数据,然后学习逆转这一过程以从随机噪声中恢复原始数据[^3]。 3. **去噪过程**:在扩散模型中,去噪是一个至关重要的步骤。模型通过多轮迭代逐渐去除生成图像中的噪声,最终得到清晰、符合文本描述的图像。这一过程可以看作是从随机噪声开始,逐步调整图像特征,使其更接近目标文本描述的过程[^1]。 4. **多模态融合**:为了更好地结合文本图像信息,一些先进的模型采用多模态融合策略。例如,ControlNet [9] 允许用户通过额外的结构化输入(如边缘图、深度图等)进一步控制生成结果,从而提高生成图像的质量和可控性[^3]。 ### 常用工具 目前市面上有多种成熟的 Text-to-Image 工具和框架,以下是几个较为知名的工具: 1. **Stable Diffusion**:Stable Diffusion 是一种基于扩散模型的文本图像生成工具,由 LAION 数据集训练而成。它能够根据简短的文本描述生成高质量的图像,并支持多种定制化选项。此外,Stable Diffusion 还可以通过替换图像编码器来实现图像提示功能,尽管这种方法可能会限制模型的灵活性。 2. **DALL-E**:DALL-E 是由 OpenAI 开发的一款文本图像生成模型,基于 Transformer 架构。它可以生成具有高度创意性和多样性的图像,适用于艺术创作、设计等领域。DALL-E 的最新版本 DALL-E 3 支持更复杂的文本描述,并能生成更高分辨率的图像。 3. **TISE (Text-to-Image Synthesis Evaluation)**:TISE 是一个用于评估文本生成图像质量的 Python 工具箱。它提供了多种评估指标,帮助研究人员和开发者量化生成图像的质量和多样性。这对于优化生成模型和提升用户体验非常有用[^4]。 4. **ControlNet**:ControlNet 是一种增强型扩散模型插件,允许用户通过额外的结构化输入(如边缘图、深度图等)来控制生成图像的细节。这使得生成的图像更加符合用户的预期,特别是在需要精确布局和结构的应用场景中。 ### 示例代码 以下是一个使用 Stable Diffusion 模型生成图像的简单示例代码: ```python from diffusers import StableDiffusionPipeline import torch # 加载预训练的 Stable Diffusion 模型 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) pipe = pipe.to("cuda") # 将模型移动到 GPU 上 # 输入文本描述 prompt = "A beautiful sunset over the ocean with a boat in the distance" # 生成图像 image = pipe(prompt).images[0] # 保存生成的图像 image.save("generated_image.png") ``` 这段代码展示了如何使用 Hugging Face 提供的 `diffusers` 库加载预训练的 Stable Diffusion 模型,并根据给定的文本描述生成图像。生成的图像会被保存为 `generated_image.png` 文件。 ###
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值