Building a text-image pair dataset with CLIP

Given images and their corresponding texts, we use CLIP to score and filter them into a text-image pair dataset.


1. Image preprocessing

Sample code:

from torchvision import transforms

def crop_to_square(img):
    """Resize the shorter side to 512 and center-crop to a 512x512 square."""
    size = 512
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
        ]
    )
    return image_transforms(img)
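
A quick sanity check of the helper; the image path below is only a placeholder:

from PIL import Image

img = Image.open("example.jpg").convert("RGB")   # any RGB PIL image
square_img = crop_to_square(img)
print(square_img.size)  # (512, 512)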

2. The CLIP class

A small wrapper around the OpenAI CLIP model exposing text embedding, image embedding, and image-text similarity scoring.

Sample code:

import torch
import clip


class CLIP(object):
    def __init__(self):
        self.device = "cuda"
        # Load the pretrained ViT-B/32 CLIP model and its preprocessing pipeline.
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        self.model = model
        self.preprocess = preprocess
        self.tokenizer = clip.tokenize

    def text_emb(self, text_ls):
        """Encode one string or a list of strings into CLIP text embeddings."""
        if isinstance(text_ls, str):
            text_ls = [text_ls]
        text = self.tokenizer(text_ls, truncate=True).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text)
        return text_features

    def img_emb(self, img):
        """Encode a single PIL image into a CLIP image embedding."""
        image = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return image_features

    def __call__(self, image, text, softmax=False):
        """Score image(s) against text(s).

        With softmax=True, return per-image probabilities over the text prompts;
        otherwise return the cosine similarity of the first image-text pair.
        """
        if isinstance(text, str):
            text = [text]

        if isinstance(image, list):
            image = [self.preprocess(i).unsqueeze(0).to(self.device) for i in image]
            image = torch.cat(image)
        else:
            image = self.preprocess(image).unsqueeze(0).to(self.device)

        text = self.tokenizer(text).to(self.device)

        if softmax:
            with torch.no_grad():
                logits_per_image, logits_per_text = self.model(image, text)
                probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            return probs
        else:
            with torch.no_grad():
                image_features = self.model.encode_image(image)
                text_features = self.model.encode_text(text)

                # L2-normalize so the dot product is a cosine similarity.
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)
                similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
                s = similarity[0][0]

            return s
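
A minimal usage sketch of the wrapper, assuming a CUDA-capable GPU and an illustrative local image file:

from PIL import Image

clip_model = CLIP()

img = Image.open("example.jpg").convert("RGB")  # placeholder image

# Cosine similarity between the image and one caption.
score = clip_model(img, "a photo of a dog")
print("similarity:", score)

# Probabilities over several candidate captions for the same image.
probs = clip_model(img, ["a photo of a dog", "a photo of a cat"], softmax=True)
print("probs:", probs)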

3. Main script

The full pipeline: load pickled {"img", "text"} samples, square-crop each image, keep those whose CLIP score against "a photo of a {concept}" exceeds 0.24, rank the surviving prompts by text similarity to the concept, and save a random subset of the top candidates. It assumes crop_to_square and the CLIP class defined above.

Sample code:

import os
import sys
import torch
from PIL import Image
import glob
import pickle
import argparse
from torchvision import transforms
import numpy as np
import random
from sklearn.metrics.pairwise import cosine_similarity
import clip

def main():
    clip_model = CLIP()
    data_dir = args.directory

    source_concept = args.concept
    os.makedirs(args.outdir, exist_ok=True)
    all_data = glob.glob(os.path.join(data_dir, "*.p"))
    res_ls = []
    for idx, cur_data_f in enumerate(all_data):
        # Each pickle holds one sample: {"img": np.ndarray, "text": str}.
        cur_data = pickle.load(open(cur_data_f, "rb"))
        cur_img = Image.fromarray(cur_data["img"])
        cur_text = cur_data["text"]

        cur_img = crop_to_square(cur_img)
        score = clip_model(cur_img, "a photo of a {}".format(source_concept))
        # Keep only samples whose CLIP confidence against the concept prompt exceeds 0.24.
        if score > 0.24:
            res_ls.append((cur_img, cur_text))

    if len(res_ls) < args.num:
        raise Exception("Not enough data from the source concept to select from. Please add more in the folder.")

    # Rank the surviving prompts by cosine similarity to the concept prompt,
    # keep the top 300 candidates, then randomly sample args.num of them.
    all_prompts = [d[1] for d in res_ls]
    text_emb = clip_model.text_emb(all_prompts)
    text_emb_target = clip_model.text_emb("a photo of a {}".format(source_concept))
    text_emb_np = text_emb.cpu().float().numpy()
    text_emb_target_np = text_emb_target.cpu().float().numpy()
    res = cosine_similarity(text_emb_np, text_emb_target_np).reshape(-1)
    candidate = np.argsort(res)[::-1][:300]
    random_selected_candidate = random.sample(list(candidate), args.num)
    final_list = [res_ls[i] for i in random_selected_candidate]
    # Save each selected pair back in the same {"img", "text"} pickle format.
    for i, data in enumerate(final_list):
        img, text = data
        cur_data = {
            "img": np.array(img),
            "text": text,
        }
        pickle.dump(cur_data, open(os.path.join(args.outdir, "{}.p".format(i)), "wb"))


def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--directory', type=str,
                        help="folder containing the input .p pickle files", default='')
    parser.add_argument('-od', '--outdir', type=str,
                        help="folder to write the selected text-image pairs to", default='')
    parser.add_argument('-n', '--num', type=int,
                        help="number of text-image pairs to select", default=100)
    parser.add_argument('-c', '--concept', type=str, required=True,
                        help="source concept, used to build the prompt 'a photo of a {concept}'")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_arguments(sys.argv[1:])
    main()
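
The script expects each .p file in the input directory to be a pickled dict with an "img" array and a "text" caption, as read by main() above. A small sketch for producing such files from raw images and captions; the file names and captions below are purely illustrative:

import os
import pickle
import numpy as np
from PIL import Image

# Hypothetical raw data: (image path, caption) pairs.
raw_samples = [
    ("raw/dog_001.jpg", "a brown dog running on the beach"),
    ("raw/dog_002.jpg", "a puppy sleeping on a couch"),
]

os.makedirs("data", exist_ok=True)
for i, (path, caption) in enumerate(raw_samples):
    img = np.array(Image.open(path).convert("RGB"))
    pickle.dump({"img": img, "text": caption},
                open(os.path.join("data", "{}.p".format(i)), "wb"))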