Audio & Video Development Journey (97) - Multimodal Models: LLaVA

Contents

1. Dataset Construction

2. LLaVA Network Architecture

3. Source Code Analysis

4. References

LLaVA builds a multimodal large model from a vision encoder (CLIP-ViT) and a large language model (LLaMA), and is instruction-tuned on data samples generated by GPT-4.

1. Dataset Construction

The training samples are generated with GPT-4, as follows:

Each image is fed to GPT-4 in the form of captions (descriptive titles) and bounding boxes; combined with type-specific prompts, GPT-4 generates three kinds of instruction-following samples: Conversation, Detailed description, and Complex reasoning, as shown in the figure below:

[Figure: examples of the three types of GPT-4-generated instruction-following data]
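To make this concrete, the sketch below shows how an image's symbolic description (captions plus normalized bounding boxes) can be assembled into a request for the text-only GPT-4. The helper name build_gpt4_context and the prompt wording are illustrative, not the exact strings from the LLaVA repo:

# A minimal sketch (illustrative, not LLaVA's exact prompts): the image is
# described to the text-only GPT-4 as captions plus bounding boxes.

def build_gpt4_context(captions, boxes):
    """Pack captions and bounding boxes into a textual description of the image."""
    lines = list(captions)
    for label, (x1, y1, x2, y2) in boxes:
        lines.append(f"{label}: [{x1:.3f}, {y1:.3f}, {x2:.3f}, {y2:.3f}]")
    return "\n".join(lines)

context = build_gpt4_context(
    captions=["A group of people standing outside of a black vehicle with various luggage."],
    boxes=[("person", (0.681, 0.242, 0.774, 0.694)),
           ("suitcase", (0.124, 0.590, 0.382, 0.812))],
)

# One request per sample type: Conversation / Detailed description / Complex reasoning.
system_prompt = ("You are an AI visual assistant looking at an image described below. "
                 "Design a conversation between you and a person asking about this photo.")
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": context},
]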

2. LLaVA Network Architecture

[Figure: LLaVA architecture - vision encoder, projection W, and language model]

The LLaVA network consists of three components: a vision encoder (Vision Encoder), a projection layer (Projection W), and a large language model (Language Model).

The vision encoder is a pretrained CLIP-ViT model; it takes an image X_v as input and outputs visual features Z_v.
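As a side note, these features can be reproduced with the Hugging Face transformers CLIP classes. A minimal sketch, assuming the clip-vit-large-patch14-336 checkpoint used by LLaVA-1.5 and a placeholder image path (LLaVA takes the patch features from the second-to-last layer and drops the CLS token):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

image = Image.open("view.jpg").convert("RGB")   # placeholder image path
pixel_values = processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    out = encoder(pixel_values, output_hidden_states=True)
# Second-to-last layer, CLS token dropped: [1, 576, 1024] patch features Z_v
z_v = out.hidden_states[-2][:, 1:]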

The projection layer applies a trainable projection matrix W that converts the image features into language embeddings H_v, so that the visual features become compatible with the language model's embedding space.

The conversion is:

Z_v = g(X_v),  H_v = W · Z_v

The visual embeddings H_v and the text embeddings H_q are then concatenated and fed into the LLM, which generates the answer X_a.
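A minimal PyTorch sketch of the projection and concatenation. The dimensions assume CLIP ViT-L (1024-d features, 576 patches) and a 7B LLaMA (4096-d embeddings); LLaVA-1.0 uses a single linear layer for W, while LLaVA-1.5 swaps in a two-layer MLP:

import torch
import torch.nn as nn

# Assumed dimensions: CLIP ViT-L patch features (1024-d, 576 patches)
# projected into the LLaMA-7B embedding space (4096-d).
projector = nn.Linear(1024, 4096)      # the trainable projection W

z_v = torch.randn(1, 576, 1024)        # visual features Z_v from the vision encoder
h_v = projector(z_v)                   # H_v = W · Z_v

h_q = torch.randn(1, 43, 4096)         # text embeddings H_q (e.g. a 43-token prompt)
inputs_embeds = torch.cat([h_v, h_q], dim=1)   # [1, 576 + 43, 4096], fed to the LLM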

LLaVA is trained in two stages:

Stage 1: Pre-training. The vision encoder (CLIP-ViT) and the large language model are both frozen; only the projection W is trained, on 595K image-text pairs filtered from CC3M.

Stage 2: Fine-tuning. Only the vision encoder stays frozen; the weights of both the projection layer and the LLM are updated (see the sketch below).
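In PyTorch the two stages boil down to toggling requires_grad. A minimal sketch, where the mm_projector / vision_tower attribute names follow the LLaVA repo but the helper itself is hypothetical:

def set_trainable_for_stage(model, stage):
    # Freeze everything first, then re-enable what each stage trains.
    for p in model.parameters():
        p.requires_grad = False

    # Both stages train the projection W.
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = True

    if stage == 2:
        # Stage 2 also trains the LLM; only the vision tower stays frozen.
        for name, p in model.named_parameters():
            if "vision_tower" not in name:
                p.requires_grad = True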

3. Source Code Analysis

3.1 Basic Usage

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

class LLavaModel():
    def __init__(self):
        self.model_path = "llava-v1.5-7b"
        # model_name = llava-v1.5-7b
        self.model_name = get_model_name_from_path(self.model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_base=None,
            model_name=self.model_name,
            load_4bit=True
        )
        self.prompt = "What are the things I should be cautious about when I visit here?"

    def inference_image(self, image_file):
        # Pack the inference arguments into an ad-hoc namespace object
        args = type('Args', (), {
            "model_name": self.model_name,
            "query": self.prompt,
            "image_file": image_file,
            "sep": ",",
            "temperature": 0,
            "top_p": None,
            "num_beams": 1,
            "max_new_tokens": 512
        })()
        # inference_model is walked through in section 3.3
        return inference_model(args, self.tokenizer, self.model, self.image_processor, self.context_len)
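Usage is then straightforward (the image path here is a placeholder):

if __name__ == "__main__":
    llava = LLavaModel()                        # loads tokenizer, model and image processor once
    answer = llava.inference_image("view.jpg")  # any path/URL accepted by image_parser
    print(answer)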

3.2 Model Loading

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM

def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False,
                          device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    kwargs = {"device_map": device_map, **kwargs}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:  # int4 quantization to reduce GPU memory usage
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    # Flash Attention improves the efficiency of the attention mechanism,
    # reducing compute complexity and memory usage while preserving
    # (or improving) model quality
    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    # Load the LLaVA model
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs
    )
    image_processor = None
    # len(tokenizer): 32000
    model.resize_token_embeddings(len(tokenizer))

    # vision_tower comes from the model config, e.g. clip-vit-large-patch14-336
    vision_tower = model.get_vision_tower()
    vision_tower.load_model(device_map=device_map)
    image_processor = vision_tower.image_processor
    context_len = 2048
    return tokenizer, model, image_processor, context_len
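As a quick sanity check of the int4 quantization, transformers models expose get_memory_footprint(); loaded with load_4bit=True, a 7B model should report roughly a quarter of its fp16 footprint (the numbers below are ballpark figures, not measurements):

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="llava-v1.5-7b",
    model_base=None,
    model_name="llava-v1.5-7b",
    load_4bit=True,
)
# Roughly 7B params * 0.5 byte ≈ 3.5 GiB in int4, vs ~13 GiB at fp16
print(f"memory footprint: {model.get_memory_footprint() / 1024**3:.1f} GiB")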

3.3 Inference Flow

import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token
from llava.eval.run_llava import image_parser, load_images

def inference_model(args, tokenizer, model, image_processor, context_len):
    model_name = args.model_name
    # Prepend the image placeholder to the user's question
    qs = DEFAULT_IMAGE_TOKEN + "\n" + args.query
    conv_mode = "llava_v1"
    """
    Each model maps to a different conversation template; llava_v1 is configured as:
    conv_llava_v1 = Conversation(
        system="A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("USER", "ASSISTANT"),
        version="v1",
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.TWO,
        sep=" ",
        sep2="</s>",
    )
    """
    conv = conv_templates[conv_mode].copy()
    # The next two lines append messages=[['USER', '<image>\n...'], ['ASSISTANT', None]]
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)

    # prompt becomes: "A chat between a curious human and an artificial intelligence
    # assistant. The assistant gives helpful, detailed, and polite answers to the
    # human's questions. USER: <image>\n{query} ASSISTANT:"
    prompt = conv.get_prompt()

    # load_images uses Image.open(image_file).convert('RGB'); convert('RGB')
    # unifies the color mode to keep the data consistent and compatible
    image_files = image_parser(args)
    images = load_images(image_files)
    image_sizes = [x.size for x in images]

    # The image processor performs the usual cropping, flipping, normalization, etc.
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Convert the prompt into input token ids; e.g. with an empty query the
    # prompt above tokenizes to:
    # tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
    #          21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
    #            322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
    #          29889,  3148,  1001, 29901, 29871,  -200, 29871,    13,   319,  1799,
    #           9047, 13566, 29901]], device='cuda:7')
    # where -200 (IMAGE_TOKEN_INDEX) is the placeholder for the image
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    with torch.inference_mode():
        # Run generation; model is a LlavaLlamaForCausalLM
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=True if args.temperature > 0 else False,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    # Decode the output tokens back into a string
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return outputs
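For reference, tokenizer_image_token essentially splits the prompt on the <image> tag, tokenizes each text chunk, and splices IMAGE_TOKEN_INDEX (-200) in between. A simplified sketch (it assumes the tokenizer prepends a BOS token to every chunk, which the real implementation checks for explicitly):

def tokenizer_image_token_sketch(prompt, tokenizer, image_token_index=-200):
    # Tokenize the text around each <image> placeholder separately
    chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
    input_ids = list(chunks[0])                 # first chunk keeps its BOS token
    for chunk in chunks[1:]:
        input_ids.append(image_token_index)     # -200 marks where the image goes
        input_ids.extend(chunk[1:])             # drop the duplicated BOS of later chunks
    return input_ids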

3.4 Fusing Image and Text Embeddings

LlavaMetaForCausalLM#prepare_inputs_labels_for_multimodal combines the image embeddings with the text embeddings into a single input sequence, which is ultimately passed to the large model (e.g., LLaMA) for causal language modeling.
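A simplified single-image sketch of that splice (the real method additionally rebuilds attention_mask, position_ids and labels, and handles batches and multiple images):

import torch

IMAGE_TOKEN_INDEX = -200

def merge_image_and_text(input_ids, image_features, embed_tokens):
    """input_ids: [seq_len]; image_features: [num_patches, hidden], already projected by W."""
    pos = (input_ids == IMAGE_TOKEN_INDEX).nonzero()[0].item()  # position of the placeholder
    pre = embed_tokens(input_ids[:pos])        # text embeddings before <image>
    post = embed_tokens(input_ids[pos + 1:])   # text embeddings after <image>
    # The image embeddings take the placeholder's slot in the sequence
    return torch.cat([pre, image_features, post], dim=0)

In the repo, this logic lives inside prepare_inputs_labels_for_multimodal, which the generate() override below invokes: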

from typing import Optional, Union

import torch
from transformers import LlamaForCausalLM
from transformers.generation.utils import GenerateOutput
from llava.model.llava_arch import LlavaMetaForCausalLM

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # Merge the image and text embeddings into one inputs_embeds sequence
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            # Text-only input: just embed the token ids
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

4. References

Paper, Visual Instruction Tuning (LLaVA): https://arxiv.org/abs/2304.08485

Paper, Improved Baselines with Visual Instruction Tuning (LLaVA-1.5): https://arxiv.org/abs/2310.03744

LLaVA: https://github.com/haotian-liu/LLaVA

LLaVA-NeXT: https://github.com/LLaVA-VL/LLaVA-NeXT

Notes on Microsoft's LLaVA paper: https://zhuanlan.zhihu.com/p/625723805

The multimodal large model LLaVA: https://blog.51cto.com/u_16600817/10098069

Multimodal large models in detail: LLaVA + LLaVA-1.5 + LLaVA-Med: https://posts.careerengine.us/p/65bf9ba51fb10327d7178303

Emo Visual Data, a meme visual dataset: https://hyper.ai/datasets/32012

Video, the multimodal model LLaVA explained: https://www.bilibili.com/video/BV1nw4m1S7nZ

Video, training a LLaVA model (dataset construction, Trainer-based training framework): https://www.bilibili.com/video/BV1Si421v7j1

Thank you for reading.

Next we will continue studying and writing about AI-related topics. Follow the WeChat public account "音视频开发之旅" to learn and grow together.

Comments and discussion are welcome.
