原文链接:https://arxiv.org/abs/2304.08485
原文摘要:使用机器生成的指令遵循数据对大型语言模型(LLMs)进行指令微调,已被证明可以提高在新任务上的零样本能力,但在多模态领域的探索较少。我们首次尝试使用仅语言的GPT-4来生成多模态语言-图像指令遵循数据。通过对这些生成数据进行指令微调,引入LLaVA:大型语言和视觉助手,一种端到端训练的大型多模态模型,用于通用视觉和语言理解。为了促进未来在视觉指令遵循方面的研究,我们构建了两个具有多样化和具有挑战性的应用导向任务的评估基准。实验表明,LLaVA展示了令人印象深刻的多模态聊天能力,有时在未见过的图像/指令上表现出与多模态GPT-4相似的行为,并且在合成多模态指令遵循数据集上相对于GPT-4的相对得分为85.1%。在Science QA上进行微调时,LLaVA和GPT-4的协同作用实现了92.53%的新纪录准确率。
一、核心要点
LLaVA 是一种大型多模态模型,通过指令微调将视觉编码器与大型语言模型(LLM)相结合,显著提升了视觉和语言任务的性能,尤其在多模态对话和指令遵循方面表现出色。
二、研究背景
- 人类与世界交互的多模态需求:人类通过多种渠道(如视觉和语言)与世界交互,每种渠道都有独特优势,能够帮助人类更好地理解世界。开发能够遵循多模态视觉-语言指令的通用助手是人工智能的一个核心目标。
- 现有研究的局限性:在计算机视觉领域,虽然已有研究致力于开发语言增强型基础视觉模型,但这些模型通常仅针对特定任务设计,且语言仅用于描述图像内容,缺乏对用户指令的交互性和适应性。另一方面,大型语言模型(LLMs)在遵循文本指令方面表现出色,但尚未在多模态领域得到充分探索。
三、主要贡献
- 多模态指令遵循数据生成:提出了一种数据转换视角和管道,利用ChatGPT/GPT-4将图像-文本对转换为指令遵循格式,解决了多模态视觉-语言指令遵循数据匮乏的问题。
- 大型多模态模型开发:构建了一个大型多模态模型LLaVA,通过连接CLIP视觉编码器和Vicuna语言解码器,并在生成的指令视觉-语言数据上进行端到端微调,验证了使用生成数据进行多模态模型指令微调的有效性。
- 多模态指令遵循基准构建:提出了LLaVA-Bench,包含两个具有挑战性的基准测试,涵盖了多样化且具有挑战性的应用导向任务,用于评估模型的多模态指令遵循能力。
- 开源资源:公开了生成的多模态指令数据、代码库、模型检查点以及视觉聊天演示,为后续研究提供便利。
四、前置知识
4.1 指令微调
4.1.1 指令微调的定义
指令微调是一种针对预训练语言模型的微调方法,通过在模型训练过程中引入指令数据集,让模型学习如何理解和执行人类的指令。这些指令数据集通常包含各种形式的指令及其对应的正确响应,模型通过学习这些指令-响应对,逐渐提升对指令的理解能力和执行能力。
4.1.2 指令微调的优势
- 使模型能够更好地理解和遵循人类指令,输出更符合人类意图的响应
- 泛化能力提升,较好的zero-shot性能
- 数据需求少、计算成本低
4.1.3 LLM中的三种微调范式对比
特点 | 模型微调 | 提示学习 | 指令微调 |
---|---|---|---|
定义 | 用特定领域的数据对模型训练,以适应特定任务。 | 在模型输入中添加提示(prompt),引导模型输出期望的结果。 | 使用包含指令和相应回答的数据训练,使模型能够理解和执行指令。 |
参数更新 | 更新模型的部分或全部参数。 | 通常不更新模型参数 | 更新模型参数,使模型能够更好地理解和生成指令相关的输出。 |
数据需求 | 需要大量特定领域的标注数据。 | 需要少量的提示样例。 | 需要包含指令和回答的数据集。 |
适用场景 | 适用于需要模型在特定任务上达到高精度的场景。 | 适用于需要快速调整模型输出以适应特定格式或风格的场景。 | 适用于需要模型具有对话能力和指令执行能力的场景。 |
它们的微调流程可以表示成下图:
其中,指令微调和提示微调的目的都是去挖掘语言模型本身具备的知识。但不同的是,Prompt 是激发语言模型的补全能力,是针对一个任务的;Instruct 则是激发语言模型的理解能力,通过给出更明显的指令,让模型去做出正确的行动,具有更好的泛化能力。
4.2 CLIP
CLIP(Contrastive Language–Image Pre-training)是一个由 OpenAI 提出的多模态模型,旨在学习图像和文本之间的对齐关系。通过将匹配的图像和文本对拉近,将不匹配的对推远,从而实现图像和文本之间的语义对齐。
4.2.1 对比学习阶段
- 给定一个 Batch 的 N N N 个(图片,文本)对,图片输入给 Image Encoder 得到表征 I 1 , I 2 , … , I N I_1, I_2, \ldots, I_N I1,I2,…,IN,文本输入给 Text Encoder 得到表征 T 1 , T 2 , … , T N T_1, T_2, \ldots, T_N T1,T2,…,TN
- N N N个Image Embedding与 N N N个Text Embedding两两配对,组成了一个 N ∗ N N*N N∗N的余弦相似性矩阵, 用于loss计算。其中,只有 ( I j , T j ) (I_j, T_j) (Ij,Tj) 属于正样本, ( I j , T k ) (I_j, T_k) (Ij,Tk) ( j ≠ k j \ne k j=k)都属于负样本。
- 使用对称交叉熵损失作为优化目标, 最大化 N N N 个正样本的余弦相似度,最小化 N 2 − N N^2 - N N2−N 个负样本的余弦相似度。
# extract feature representations of each modality
I_f = image_encoder(I) #[n,d_il 可以是ResNet or Vision Transformer
T_f = text_encoder(T) #[n,d_t] 可以是 CBOW(Continuous Bag-Of-Words) or Text Transformer
# joint multimodal embedding [n, d_e]
I_e = L2_normalize(np.dot(I_f, W_i),axis=1)
T_e = L2_normalize(np.dot(T_f, w_t),axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross.entropy_loss(logits, labels,axis=0)
loss_t = cross.entropy_loss(logits,labels,axis=1)
loss = (loss_i + loss_t) / 2
4.2.2 预测阶段
把图片输入 CLIP 预训练好的 Image Encoder,得到特征 I 1 I_1 I1,接下来把所有类别的词汇 “cat”, “dog” 等,做成 prompt:“A photo of a {object}”(在做图像分类时,要分类的类别是一个个的单词,然而CLIP预训练时候的文本端采用的是句子进行训练。所以在预测的时候也要和预训练阶段保持对齐.),并将这个 prompt 输入 CLIP 预训练好的 Text Encoder,依次得到特征 T 1 , T 2 , … , T N T_1, T_2, \ldots, T_N T1,T2,…,TN ,最后观察哪个的余弦相似度和 I 1 I_1 I1最高,就代表该图片是哪个类别的。
4.2.3 CLIP-ViT-L/14视觉编码器
CLIP-ViT-L/14 是 CLIP 模型中使用的一种特定的视觉编码器,基于 Vision Transformer(ViT)架构,具有较大的模型规模(L 表示 Large),并且使用了 14×14 的图像块(patches)作为输入的基本单元。
具体步骤:
- 输入处理 :对于输入图像,CLIP-ViT-L/14 首先将其分割成 14×14 的图像块,然后将每个图像块展平并映射到一个固定维度的特征空间。
- ViT处理 :经过位置编码处理后的图像块特征序列被输入到 Transformer 编码器中。通过多头自注意力机制和前馈网络的处理,模型能够捕捉图像块之间的复杂关系,并生成图像的全局特征表示。
- 输出 :一个高维的视觉特征向量。
4.3 Vicuna
4.3.1 Vicuna 简介
Vicuna是LLaMA模型的一个指令微调版本,由UC伯克利的研究团队开发。该模型在LLaMA的基础上,针对指令性任务进行了优化。通过引入指令嵌入(instruction embedding)和任务嵌入(task embedding),Vicuna能够更好地理解任务的意图,并生成符合指令要求的输出。
4.3.2 Vicuna 工作原理
- 输入:输入文本首先被分词器转换为一系列的标记,每个标记对应一个嵌入向量,这些嵌入向量作为模型的输入。
- 自注意力机制:在解码器层中,多头自注意力模块通过计算每个标记与其他标记之间的关系权重,动态关注输入序列中的不同部分,从而捕捉文本中的长距离依赖关系和上下文信息。
- 前馈网络:多层感知机对自注意力模块的输出进行非线性变换,进一步提取特征并增强模型的表达能力。
- 输出:经过多层解码器的处理后,模型输出每个位置上下一个标记的概率分布,通过选择概率最高的标记来生成文本序列。
五、LLaVA 模型架构
5.1 模型结构
LLaVA模型由CLIP视觉编码器和Vicuna语言解码器组成。视觉编码器将图像特征转换为与语言模型词嵌入空间维度相同的视觉标记,然后通过语言模型进行端到端训练。
为什么使用CLIP-ViT-L/14作为视觉编码器?
CLIP-ViT-L/14作为CLIP的一部分,已经在大规模的图像-文本对数据上进行了预训练,能够将图像特征映射到与文本特征相同的空间中。这种对齐能力使得LLaVA能够更好地将视觉信息与语言模型相结合,实现更自然、更准确的多模态交互。
为什么使用Vicuna作为语言解码器?
因为Vicuna模型在语言任务中具有最佳的指令遵循能力,是公开可用检查点中表现最佳的,为LLaVA提供了强大的语言生成能力。同时,Vicuna高效的微调策略允许在有限的计算资源下快速适应新任务,这对于LLaVA在多模态任务中快速调整语言解码器以匹配视觉编码器至关重要。
5.2 模型处理流程
- 输入:对于每张图像
X
v
X_v
Xv,和关于该图像的多轮对话数据
(
X
q
1
,
X
a
1
,
⋯
,
X
q
T
,
X
a
T
)
(X_q^1, X_a^1, \cdots, X_q^T, X_a^T)
(Xq1,Xa1,⋯,XqT,XaT),
第 t t t 轮的输入指令 X i n s t r u c t t X_{instruct}^t Xinstructt 定义为:
X i n s t r u c t t = { 随机选择 [ X q 1 , X v ] 或 [ X v , X q 1 ] , 第一轮 t = 1 X q t , 剩余轮次 t > 1 X_{instruct}^t = \left\{ \begin{array}{ll} \text{随机选择 } [X_q^1, X_v] \text{ 或 } [X_v, X_q^1], & \text{第一轮 } t = 1 \\ X_q^t, & \text{剩余轮次 } t > 1 \end{array} \right. Xinstructt={随机选择 [Xq1,Xv] 或 [Xv,Xq1],Xqt,第一轮 t=1剩余轮次 t>1
-
视觉特征提取和语言指令嵌入 :使用 CLIP 视觉编码器对输入图像 X v X_v Xv 进行编码,得到视觉特征 Z v Z_v Zv。然后通过一个可训练的线性投影矩阵 W W W,将 Z v Z_v Zv 转换为与语言模型词嵌入维度相同的视觉令牌 H v H_v Hv。
H v = W ⋅ Z v , Z v = g ( X v ) \mathbf{H_v} = \mathbf{W} \cdot \mathbf{Z_v}, \mathbf{Z_v} = g(\mathbf{X_v}) Hv=W⋅Zv,Zv=g(Xv)语言指令嵌入:将语言指令 X q X_q Xq转换为嵌入向量 H q H_q Hq, E \mathbf{E} E表示嵌入函数,通常包括词嵌入层和位置编码层。
H q = E ( X q ) \mathbf{H_q} = \mathbf{E}(\mathbf{X_q}) Hq=E(Xq) -
组织多模态输入序列 :将视觉令牌 H v H_v Hv 与每轮对话的指令编码 H q t H^t_{q} Hqt和之前的输入序列组织起来形成多模态输入序列,使得模型在每轮对话中都能同时考虑图像信息和之前的对话历史。
-
Vicuna解码生成回答:将构建好的多模态输入序列输入到Vicuna语言模型中,语言模型会根据输入序列生成对应的回答 X a t X^t_a Xat。具体来说,Vicuna语言模型采用自回归的方式逐词生成回答。对于长度为 L L L 的序列,通过以下公式计算目标答案 X a X_a Xa 的概率:
p ( X a ∣ X v , X i n s t r u c t ) = ∏ i = 1 L p θ ( x i ∣ X v , X i n s t r u c t , < i , X a , < i ) p(X_a|X_v, X_{instruct}) = \prod_{i=1}^{L} p_{\theta}(x_i|X_v, X_{instruct,\lt i}, X_{a, \lt i}) p(Xa∣Xv,Xinstruct)=i=1∏Lpθ(xi∣Xv,Xinstruct,<i,Xa,<i)
其中,θ是可训练参数, X i n s t r u c t , < i X_{instruct,<i} Xinstruct,<i 和 X a , < i X_{a,<i} Xa,<i 分别是当前预测标记xi之前所有轮次中的指令和回答标记。
六、模型架构代码分析
代码链接:https://github.com/LLaVA-Annonymous/LLaVA
6.1 三种注意力机制
6.1.1 scaled_multihead_dot_product_attention
scaled_multihead_dot_product_attention,标准多头点积注意力。
计算公式:
-
输入投影:
- 查询(Query)、键(Key)、值(Value)通过线性层投影,并拆分为多头:
Q = rearrange ( q , ′ b s ( h d ) → b h s d K = rearrange ( k , ′ b s ( h d ) → b h d s V = rearrange ( v , ′ b s ( h d ) → b h s d Q = \text{rearrange}(q, 'b\ s\ (h\ d) \rightarrow b\ h\ s\ d \\ K = \text{rearrange}(k, 'b\ s\ (h\ d) \rightarrow b\ h\ d\ s \\ V = \text{rearrange}(v, 'b\ s\ (h\ d) \rightarrow b\ h\ s\ d Q=rearrange(q,′b s (h d)→b h s dK=rearrange(k,′b s (h d)→b h d sV=rearrange(v,′b s (h d)→b h s d
- 查询(Query)、键(Key)、值(Value)通过线性层投影,并拆分为多头:
-
注意力分数计算:
AttnWeight = softmax ( Q K T d + attn_bias ) \text{AttnWeight} = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{attn\_bias}\right) AttnWeight=softmax(dQKT+attn_bias)softmax_scale
默认为 1 / d 1/\sqrt{d} 1/d,用于缩放点积结果。
-
掩码处理:
- 因果掩码(
is_causal=True
):上三角矩阵置负无穷(min_val
),强制未来位置不可见。 - Key填充掩码(
key_padding_mask
):将填充位置(False
)的分数置为负无穷。
- 因果掩码(
-
输出计算:
Output = AttnWeight ⋅ V \text{Output} = \text{AttnWeight} \cdot V Output=AttnWeight⋅V- 最后合并多头输出:
rearrange(out, 'b h s d → b s (h d)')
。
- 最后合并多头输出:
特点:
- 灵活性:支持自定义偏置(
attn_bias
)、填充掩码和因果掩码。 - 计算开销:显式计算注意力矩阵( O ( n 2 ) O(n^2) O(n2) 内存),适合小规模序列或需要注意力权重的场景。
6.1.2 Flash-Attn
计算公式:
-
输入解压(Unpad):
- 去除填充部分,仅处理有效token:
Q unpad = unpad ( Q , mask ) , K unpad , V unpad 同理 Q_{\text{unpad}} = \text{unpad}(Q, \text{mask}), \quad K_{\text{unpad}}, V_{\text{unpad}} \text{同理} Qunpad=unpad(Q,mask),Kunpad,Vunpad同理 - 重排为多头格式:
nnz (h d) → nnz h d
(nnz
为非零token数)。
- 去除填充部分,仅处理有效token:
-
高效注意力计算:
- 调用Flash-Attn的低级接口:
Output unpad = flash_attn_unpadded ( Q unpad , K unpad , V unpad , cu_seqlens , causal = reset_is_causal ) \text{Output}_{\text{unpad}} = \text{flash\_attn\_unpadded}(Q_{\text{unpad}}, K_{\text{unpad}}, V_{\text{unpad}}, \text{cu\_seqlens}, \text{causal}= \text{reset\_is\_causal}) Outputunpad=flash_attn_unpadded(Qunpad,Kunpad,Vunpad,cu_seqlens,causal=reset_is_causal) - 内部使用分块计算和IO优化,避免显式存储注意力矩阵。
- 调用Flash-Attn的低级接口:
-
输出重压缩(Pad):
- 将结果重新填充到原始序列长度:
Output = pad ( Output unpad , indices q ) \text{Output} = \text{pad}(\text{Output}_{\text{unpad}}, \text{indices}_q) Output=pad(Outputunpad,indicesq)
- 将结果重新填充到原始序列长度:
特点:
- 高效性:通过减少冗余计算(如填充部分)和内存访问优化,显著提升长序列性能。
- 限制:
- 不支持
attn_bias
,需通过掩码间接实现。 - 需要安装Flash-Attn库,依赖CUDA硬件。
- 不支持
6.1.3 Triton-Flash-Attn
计算公式:
-
输入投影:
- 类似标准多头注意力,但使用Triton内核处理:
Q , K , V = rearrange ( . . . ) , multiquery时复制K/V的头数 Q, K, V = \text{rearrange}(...), \quad \text{multiquery时复制K/V的头数} Q,K,V=rearrange(...),multiquery时复制K/V的头数
- 类似标准多头注意力,但使用Triton内核处理:
-
注意力计算:
- 调用Triton实现的Flash-Attn:
Output = triton_flash_attn ( Q , K , V , attn_bias , causal = reset_is_causal ) \text{Output} = \text{triton\_flash\_attn}(Q, K, V, \text{attn\_bias}, \text{causal}= \text{reset\_is\_causal}) Output=triton_flash_attn(Q,K,V,attn_bias,causal=reset_is_causal) - 支持动态掩码(如
key_padding_mask
转换为attn_bias
)。
- 调用Triton实现的Flash-Attn:
-
输出合并:
- 直接合并多头输出:
b s h d → b s (h d)
。
- 直接合并多头输出:
特点:
- 硬件适配:利用Triton编译器优化GPU内核,适合特定架构(如Ampere)。
- 限制:
- 不支持dropout和返回注意力权重。
- 依赖Triton版本。
6.1.4 对比总结
特性 | 标准多头注意力 | Flash-Attn | Triton-Flash-Attn |
---|---|---|---|
内存效率 | 低(显式矩阵) | 高(解压输入) | 中(内核优化) |
支持掩码 | 全部 | 仅因果/填充掩码 | 通过attn_bias 转换 |
长序列优化 | 无 | 是 | 是 |
依赖库 | 无 | Flash-Attn | Flash-Attn + Triton |
适用场景 | 小规模/需权重 | 大规模无padding | 特定GPU架构 |
6.2 模型训练
6.2.1 训练流程
LLaVA的训练流程可分为以下核心步骤:
- 参数解析:解析模型/数据/训练三组参数。
- 模型加载:加载预训练语言模型(LLaMA或MPT)和视觉编码器(CLIP)。
- 多模态适配器初始化:连接视觉和语言模态的投影层。
- 数据预处理:处理图文对话数据,生成tokenized输入。
- 训练循环:使用FSDP(完全分片数据并行)优化多模态模型。
6.2.2 关键逻辑
(1) 模型架构
- 视觉编码器:CLIP视觉塔(
vision_tower
)提取图像特征。 - 语言模型:LLaMA或MPT作为文本生成主干。
- 多模态投影层 (
mm_projector
):- 将图像特征映射到语言模型空间:
[image_seq, d_vision] -> [image_seq, d_text]
- 结构:两层MLP(线性层 + GELU激活 + 线性层):
nn.Sequential( nn.Linear(vision_hidden_size, text_hidden_size), nn.GELU(), nn.Linear(text_hidden_size, text_hidden_size) )
- 将图像特征映射到语言模型空间:
(2) 关键参数配置
@dataclass
class ModelArguments:
vision_tower: str = "openai/clip-vit-large-patch14" # CLIP视觉编码器
mm_vision_select_layer: int = -1 # 使用CLIP最后一层特征
@dataclass
class DataArguments:
image_aspect_ratio: str = 'pad' # 图像处理策略(pad/keep)
@dataclass
class TrainingArguments:
freeze_mm_mlp_adapter: bool = False # 是否冻结投影层
(3) 数据预处理——文本
# 示例输入(JSON格式)
{
"conversations": [
{"from": "human", "value": "<image>Describe this image."},
{"from": "gpt", "value": "A cat sitting on a couch."}
],
"image": "cat.jpg"
}
文本模板化:
- 添加角色标记(
### Human:
、### GPT:
)。 - 替换
<image>
为图像token(如<im_patch>*256
)。 - 使用LLaMA tokenizer分词,对Human部分的标签掩码(
IGNORE_INDEX=-100
)。 - 若启用
mm_use_im_start_end
,添加特殊标记:<im_start> + <im_patch>*image_token_len + <im_end>
。
对话风格:
- Vicuna风格(默认):
### Human: <image>Describe this image. ### GPT: A cat sitting on a couch.
- MPT风格:
- 使用
<im_start>
和<im_end>
标记。
- 使用
# 处理对话数据,添加信号、连接对话、分词,并对目标张量进行掩码。
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if conversation_lib.default_conversation.version == "v1":
return preprocess_v1(sources, tokenizer)
if conversation_lib.default_conversation.version == "mpt":
return preprocess_mpt(sources, tokenizer)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
tokenizer)["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
(4) 数据预处理——图像
- 保持宽高比 (
image_aspect_ratio='keep'
):- 缩放短边至224-448像素,保持比例。
- 填充为方形 (
image_aspect_ratio='pad'
):- 用均值颜色填充图像为正方形。
- ViT特征提取:
- 图像被分割为14x14的patch,输出序列长度=
(H//14)*(W//14)
。
- 图像被分割为14x14的patch,输出序列长度=
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
# ...
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# ...
# 图像处理逻辑
if 'image' in sources[0]:
image_file = self.list_data_dict[i]['image']
image_folder = self.multimodal_cfg['image_folder']
processor = self.multimodal_cfg['image_processor']
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
if self.multimodal_cfg['image_aspect_ratio'] == 'keep': # 保持宽高比
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
elif self.multimodal_cfg['image_aspect_ratio'] == 'pad': # 填充为方形
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else: # ViT特征提取**:图像被分割为14x14的patch
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.multimodal_cfg, cur_token_len)
else:
# ...
return data_dict
(4) 投影层训练
tune_mm_mlp_adapter
和freeze_mm_mlp_adapter
是LLaVA中平衡视觉-语言模态交互的关键开关,用于控制**多模态投影层(MM_MLP Adapter)**的训练状态。
- 冻结投影层:保持预训练的投影关系,适合推理或稳定性优先的场景。
- 不冻结投影层:允许投影层适应下游任务,提升多模态融合能力。
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False) # 冻结主干
for p in model.get_model().mm_projector.parameters(): # 仅训练投影层
p.requires_grad = True
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
if training_args.freeze_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters(): # 冻结投影层
p.requires_grad = False
典型使用场景
配置组合 | 训练目标 | 适用阶段 |
---|---|---|
freeze_mm_mlp_adapter=True | 固定投影层,优化其他部分 | 视觉-语言对齐初步测试 |
freeze_mm_mlp_adapter=False + tune_mm_mlp_adapter=True | 仅训练投影层 | 轻量级多模态适配 |
freeze_mm_mlp_adapter=False + freeze_backbone=False | 联合训练所有参数 | 端到端微调 |
七、LLaVA 训练
7.1 数据生成
LLaVA的训练数据生成主要依赖于语言模型GPT-4来生成多模态语言-图像指令跟随数据。研究者们利用GPT-4将图像-文本对转换为适当的指令跟随格式。首先,对于图像
X
v
X_v
Xv 和其相关的标题
X
c
X_c
Xc,研究者创建了一组问题
X
q
X_q
Xq用来指示助手描述图像内容。然后,他们提示GPT-4来编制这类问题列表。
为了将图像编码成其视觉特征,使用了两种类型的符号表示将图像编码为 GPT的可识别序列:
- (i)标题通常从各种角度描述视觉场景;
- (ii)边界框通常在场景中定位对象,并编码对象概念及其空间位置。
通过将图像内容编码为符号表示(标题和边界框),提示GPT-4生成与视觉内容相关的指令和回答,涵盖对话、详细描述和复杂推理三种类型。对于每种类型,首先手动设计一些示例,作为数据收集过程中唯一的人工注释,相当于few-shot。
7.2 训练过程
7.2.1 第一阶段:预训练(θ = W)
保持视觉编码器和LLM权重冻结,仅最大化投影矩阵W的似然性。使得视觉特征
H
v
H_v
Hv可以与LLM的词嵌入对齐(这个阶段可以理解为训练一个与冻结LLM兼容的视觉分词器)
将CC3M过滤为595K图像-文本对,每个样本可以视为一个单轮对话。对于图像
X
v
X_v
Xv,随机采样一个问题
X
q
X_q
Xq,要求助手简要描述图像。真实预测答案
X
a
X_a
Xa是原始标题。
7.2.2 第二阶段:端到端微调(θ = {W, ϕ})
保持视觉编码器权重冻结,并继续更新LLaVA中预训练权重的投影层和LLM。
考虑两种特定的使用场景:
- 多模态聊天机器人。通过在158K语言-图像指令遵循数据上进行微调来开发聊天机器人。在三种回答类型中,对话是多轮的,而其他两种是单轮的。它们在训练中均匀采样。
- 科学问答。在ScienceQA(第一个大规模的多模态科学问题数据集)基准上研究,其答案标注有详细的解释。每个问题提供自然语言或图像形式的上下文。助手用自然语言提供推理过程,并从多个选择中选择答案。训练中将数据组织成单轮对话、问题和上下文作为 X i n s t r u c t X_{instruct} Xinstruct,推理和答案作为 X a X_a Xa。
八、实验结果
8.1 多模态聊天能力
LLaVA在多模态聊天任务中展现出与多模态GPT-4相似的推理和回答能力,即使在未见过的图像/指令上也能提供合理的回答,优于BLIP-2和OpenFlamingo等模型。
8.2 定量评估
通过构建LLaVA-Bench(COCO)和LLaVA-Bench(In-the-Wild)两个基准测试,对LLaVA的性能进行了系统评估。结果显示,LLaVA在复杂推理问题上表现出色,整体性能接近 GPT-4。
8.3 ScienceQA任务
在ScienceQA数据集上,LLaVA微调后达到了90.92%的准确率,接近当时的最先进水平。通过与GPT-4进行模型集成,进一步将准确率提升至92.53%,创造了新的最先进记录。这是因为有些问题并不需要图像上下文来得出正确答案。GPT-4裁判可以识别出这些情况,并纠正LLaVA的一些错误。
实验中使用两种方案来结合LLaVA和GPT-4的结果:
- (i)GPT-4补充。每当GPT-4无法提供答案时,使用LLaVA的预测作为最终结果。
- (ii)GPT-4作为裁判。每当GPT-4和LLaVA产生不同答案时,再次提示GPT-4,要求它根据问题和两个结果提供自己的最终答案。其精神与CoT类似,但借助了另一个模型的外部知识。
九、启发
- 多模态模型的潜力:论文展示了通过指令微调提升多模态模型性能的可能性,为后续研究提供了新的方向,即如何更好地结合视觉和语言信息来提高模型的通用性和适应性。
- 数据生成的重要性:利用GPT-4生成多模态指令遵循数据的方法为解决多模态数据匮乏问题提供了新思路,未来可以探索更多高效的数据生成策略。
- 模型集成的创新:将LLaVA与GPT-4进行集成的尝试为多模态模型的优化提供了新的思路,未来可以探索更多模型集成方法,以进一步提升多模态模型的性能。