谷歌「钞」能力 | 最强ViT模型来了!220亿参数

导读

 

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。

本文目录

52 扩展到220亿参数的巨大视觉 Transformer
(来自谷歌,含 ViT 作者)
52 ViT-22B 论文解读
52.1 背景和动机
52.2 三句话概括 ViT-22B 模型的架构
52.3 ViT-22B 的实现方法:异步并行线性操作计算
52.4 数据集和超参
52.5 图像分类迁移性能
52.6 密集预测性能
52.7 ViT-22B 与人类感知的一致性

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。

52 扩展到220亿参数的巨大视觉 Transformer

论文名称:Scaling Vision Transformers to 22 Billion Parameters

论文地址:

https://arxiv.org/pdf/2302.05442.pdf

52.1 背景和动机

与自然语言处理类似,视觉预训练大模型也提高了在各种视觉任务的性能。这不仅依赖于更大的数据集,更大的可扩展的视觉架构,也同样依赖于新的训练策略。

但尽管如此,视觉模型从规模和效果上而言,远远落后于语言模型。迄今为止最大的密集视觉模型只有 4B 参数 ViT[1],而入门级的语言模型通常包含超过 10B [2][3]个参数以上,最大的密集语言模型有 540B 个参数。稀疏模型也展示出同样的趋势,其中语言模型的参数超过了一万亿[4],但最大的稀疏视觉模型只有约 15B[5]。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT-22B,并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。ViT-22B 的质量通过分类和下游任务实验进行评估,在这些任务中它达到或提高了当前的最先进水平。

通过多模态训练一个 text tower 来匹配视觉特征,ViT- 22B 在 ImageNet 上实现了 85.9% 的 zero-shot 精度。此外,该模型是一个很好的老师——用作蒸馏目标,作者训练了一个 ViT- B 学生模型,其在 ImageNet 上达到了88.6% 的 SOTA 精度。

除了分布的改进、ViT-22 的可靠性、不确定性估计和公平性都取得了提升。更重要的是,ViT-22 的特征更好地与人类的感知保持一致,实现了之前未见过的 87% 的形状偏差 (shape bias)。

52.2 三句话概括 ViT-22B 模型的架构

ViT-22B 是一种基于 Transformer 的模型,其架构的设计类似于原始 ViT,但包含以下3个修改,以提高效率和大规模训练的稳定性。如下图1所示,用三句话概括分别是:

  • 并行设计

  • Query 和 Key 的归一化

  • 省略 bias

8e0b2e91f1d0529f3953417d2624db94.jpeg
图1:ViT-22B 模型的架构

第1句话,把 ViT 中的 Self-Attention 层换成了一个 Self-Attention 层加一个 MLP 层的并行设计,公式如下:

d2a6ca96ac93ebf2475f9112edecbda4.png

注意这里作者不是简单地将两个模块相加,而是使用了一个并行化技巧,即:用于 Self-attention 中的 Query,Key,Value 计算的矩阵乘法和 MLP 的第1个线性层被融合到一个单独的操作中;用于 Self-attention 中的输出投影和 MLP 的第2个线性层也被融合到一个单独的操作中。这种方法最初是由 PaLM[6] 提出的,该技术在不降低性能的情况下将最大模型的训练速度提高了 15%。

第2句话,Self-Attention 中 Query 和 Key 的计算过程添加了归一化。在以往 ViT 扩展的工作中,作者观察到了在训练了几千步之后,training loss 发散了的情况,这毫无疑问使得我们无法训练大型 ViT 模型,尤其是在 8B 参数左右的模型中观察到这种不稳定性。

出现这种现象的原因是:注意力矩阵的值异常 (近似于 one-hot,有的地方注意力很大,其他位置几乎为零) 引起的,这导致注意力权重的熵接近于零。为了解决这个问题,作者采用[7]的方法,将 LayerNorm 应用于 Self-Attention 中 Query 和 Key 的计算过程。具体可以写成:

式中,  是 query/key 的维度,  是 layer normalization,  分别是 Query 和 Key 的权重矩阵。对 8B 参数模型的影响如图2所示,其中归一化防止了注意力矩阵的值不受控的异常而导致的训练发散。

c81bfb032e29ccc529bceb0fb81cb082.jpeg

图2:Self-Attention 中 Query 和 Key 的计算过程添加归一化对 8B 参数模型的影响

第3句话,省略 bias。遵循 PaLM 的做法,从 QKV 投影中去除 bias,并使用没有 bias 项和 centering 的 Layer Normalization[8]。这在不影响训练质量的前提下提升了训练速度。但是,与 PaLM 不同的是,作者对所有 MLP 层使用了 bias 项,因为观察到质量得到了改善,且训练速度没有下降。

ViT-22B 使用 14×14 大小的 Patch,输入图像分辨率为 224×224。类似于原始 ViT,ViT-22B 也采用了可学习的位置编码。在对高分辨率图像 (不同数量的 Patch) 进行微调期间,也对预训练的位置编码执行二维插值。ViT-22B和 ViT-G,ViT-e 的超参数对比如下图3所示。

cddb2757145845d2634a7474e45172ee.png

图3:ViT-22B和 ViT-G,ViT-e 的超参数对比

ViT-22B 的 model card 如下图所示。

624c90358a339f28fb623e429429b5b1.jpeg7999269bc0c7a4875a42231f77e502bb.jpeg

图4:ViT-22B 的 model card

52.3 ViT-22B 的实现方法:异步并行线性操作计算

ViT-22B 基于 JAX 框架和 FLAX,Scenic 库,它同时利用了模型和数据的并行性。作者使用了 jax.xmap 这个 API,其对所有中间体的分片 (例如权重和激活) 以及芯片间通信提供了明确的控制。

这个高效的实现如下图5和6所示。该怎么理解这个过程呢?

ff80bca9ba47d4f9bb8b4465a4ada9bc.jpeg

图5:矩阵 A 在不同设备之间按行切分

8c6af7f5b943d07472af71e9cced6df4.jpeg

图6:矩阵 A 在不同设备之间按列切分

为了说明这个过程, 这一段我们介绍下它的原理。考虑计算 这个矩阵乘法运算。

比如说我们有  台设备, 设  分别是  的第  个 Block , 那么  分别放在第  台设备上, 如上图5和图6所示。设  。那么现在等于说我们有  个矩阵  的分块。要把这么多块均匀地放在  台 devices 上面, 就 有两种放法。为了方便说明, 这里我们假设  吧。那么现在矩阵  就被分成了  的分块。

  • 第1种:  的每一行的块放在同一台 device 上, 即:  放在同一台上面, 对应图5。

  • 第2种:  的每一列的块放在同一台 device 上, 即:  放在同一台上面, 对应图6。

对于第1种情况, 计算  需要  次  的通信, 因为  的维度是 , 所以就需要  float 的通信。

对于第2种情况, 计算  需要  次  中间值的通信, 因为  中间值的维度是 , 所以就 需要  float 的通信。

当  时, 采用哪一种计算方式都可以。当  时, 比如 MLP 的输出层有 , 那么这个时候采用第2种情况的计算方法。

注意, 以上通信的过程与计算过程是异步的。即, 当计算一层时, 设备可以开始通信下一层的权重, 从而最大限度地减少通信开销。

作者将芯片组织成大小为  的  逻辑网格,其中  是数据轴的大小,  是模型轴的大小。然后, 对于  组中的每个组,  个设备获得相同 Batch 的图像, 每个设备只保留  的激活值, 并负责计算输出  的  。

52.4 数据集和超参

ViT-22B 在 JFT的一个版本上训练,训练集包含大约 4B 图像,并采用 Sigmoid 交叉熵损失以多标签分类方式使用所有分配的标签。

ViT-22B 把图片分为 14×14 大小的块,使用 65k 的 Batch Size 训练 177k steps (3 Epochs),初始学习率,reciprocal square-root learning rate schedule,以及 10k 的线性学习率 warmup,和 30k 的的线性学习率 cooldown。上游预训练的权重衰减 head 设为3.0,body 设为 0.03。

52.5 图像分类迁移性能

Linear Probing 实验结果

作者在 ImageNet 上使用了10个周期的动量 SGD,分辨率为 224px,使用 mild random cropping 和 horizontal flipping 作为唯一的数据增强策略,而没有进一步的正则化措施。

如下图7所示是 ViT-22B 的 Linear Probing 实验结果,虽然收益不高,但在这个尺度上仍有显著的改善。而且图7还说明,,像 ViT-22B 这样的大模型的 Linear Probing 可以接近或超过具有小模型的高分辨率完全微调的性能,而 Linear Probing 通常成本更低。

def50db0de907c641a0d6c902751439e.jpeg

图7:ImageNet 上 ViT-22B 的 Linear Probing 实验结果

作者进一步在细粒度分类数据集 iNaturalist 2017 上测试线性可分性。iNaturalist 2017 有5,089个细粒度类别,属于13个大类。与 ImageNet 不同,不同类别的图像数量是不平衡的。概念的长尾分布对分类更具挑战性。作者将 ViT- 22B 与其他 ViT 变体进行比较,并测试了 224px 和 384px 的输入分辨率。结果如图8所示,可以观察到 ViT- 22B 明显优于其他变体,特别是在标准的 224px 输入分辨率下,这表明 ViT-22B 中大量的参数对于从图像中提取详细信息是有用的。

5e5377814491216e443f8f5f9661ee2e.png

图8:iNaturalist 2017 上 ViT-22B 的 Linear Probing 实验结果

Out-of-distribution 实验结果

定义: 在分类任务中,给定测试图片,若模型在训练阶段模型见过或类似的图片,则能正确分类;但如果与训练集完全不相关,也会被强制判定为训练集类别中的一种,这种情况是不合理的。OOD 算法希望能判断的分布状况是否与训练集一致,若一致,则称为 in-distribution (ID),否则称为 out-of-distribution (OOD)。
使用场景举例: 在 MNIST 上训练的一个分类模型,然后,输入一张“马”的图片,会被归类为数字 0~9,这是错误的。此时,MNIST 数据集就是 in-distribution,相对于 ID 而言,“马”是 out-of-distribution。

作者构建了从 JFT 到 ImageNet 的标签映射,以及从 ImageNet 到不同分布外数据集的标签映射,即 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A。ImageNet-R 和 ImageNet-A 使用相同的 ImageNet 的200个标签子空间,而 ObjectNet 有313个类别,其中只考虑与 ImageNet 标签空间重叠的113个类别。那么这样以后,JFT 上面训练出的模型的预测结果就可以转化到其他数据集上面了,这就为模型的 Out-of-distribution 能力提供了一种验证手段。

Out-of-distribution 的评估一般是首先在一个较大的数据集上面 (比如 ImageNet) 做预训练,然后在 ImageNet-R ,ImageNet-A 等数据集上直接评估其性能。实验结果如下图9所示。作者做了两种实验,其一 (图9上半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。其二 (图9下半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ImageNet 数据集上微调,最后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。

0d212567deb08ff83e5c826e0e7a83b7.jpeg

图9:OOD 实验结果。标注 "ema" 的模型使用 Polyak 平均进行微调

把图9中 ObjectNet 的实验结果画出来就是图10.

e22b588d507ec91b0051e64779538053.jpeg

图10:OOD 中 ObjectNet 实验结果

从图9和10中可以得出结论:把模型做大可以增加 Out-of-distribution 的性能。这适用于只看过 JFT 图像的 ViT-22B 模型,以及在 ImageNet 上做了微调过之后的模型。在这两种情况下,ViT-22B 在更大的模型上都延续了 OOD 性能更好的趋势。即使 ImageNet 的性能饱和,但从图10中也可以看到 ObjectNet 上的精度从 ViT-e/14 到 ViT-22B 的显著提升。

52.6 密集预测性能

密集预测任务的迁移性能也是评价一个 Backbone 模型的关键因素。作者通过语义分割和单目估计任务评价 ViT-22B 捕获的几何和空间信息的质量。

语义分割任务使用 ViT-22B 作为骨干模型,UperNet 作为分类头,在 ADE20K, Pascal Context 和 Pascal VOC 三个数据集上进行评测。如下图11所示,作者将 ViT-22B 与 DeiT-III ,ViT-G 进行了比较,且只使用了一部分的训练数据。作者使用线性解码器和端到端微调。从图11中可以观察到,当只使用少量的数据时,ViT-22B 骨干更好。例如,当只对1200张图像 (即1/16) 的 ADE20K 训练数据进行微调时,ViT-22B 达到了 44.7 mIoU 的性能,比 DeiT-III Large 提高了8.6 mIoU,比 ViT-G 提高了 2.3 mIoU。当数据量较大时,ViT-G 和 ViT-22B 的性能趋于一致。

36bbc0dac66a193541aaf19102d06628.jpeg

图11:ADE20K Fewshot 语义分割实验结果

单目估计任务使用 ViT-22B 作为骨干模型,Dense Prediction Transformer (DPT) 作为单目估计头,或者仅仅使用一个简单的线性解码器作为估计头。实验在 Waymo Open real-world driving 数据集上进行评测。如下图12所示,上半部分 (DPT解码器) 的结果可以观察到,与不同的主干相比,使用 ViT-22B 骨干网络得到了最好的性能。通过将 ViT-22B 骨干与 ViT-e 进行比较,发现扩展架构可以提高性能。使用线性解码器,可再次观察到使用 ViT-22B 骨干模型可获得最佳性能。DPT 和线性解码器之间的差距表明,虽然在 ViT 特征中保留了足够的几何信息,但只有一部分可被普通的线性解码器利用到。

ee366c4036053b339683a79368c289a1.jpeg

图12:Waymo Open dataset 单目估计实验结果

在研究规模的影响时,除了下游任务性能之外,还有一些重要的方面需要考虑。比如一个主干网络的与人类感知的一致性。

52.7 ViT-22B 与人类感知的一致性

ViT-22B 分类决策与人类分类决策的一致性如何?通过这篇论文 "Partial success in closing the gap between human and machine vision" 的方法,作者评估了在 ImageNet 上以不同分辨率 (224,384,560) 微调的三个 ViT-22B 模型。在所有指标中,ViT-22B-224 具有最高的 OOD 稳健性 (图13 (a)), ViT-22B -384 与人类分类精度最接近(图13(b)), ViT-22B-560 具有最大的错误一致性 (即大多数类似人类的错误模式,图13(d))。ViT-22B 模型在视觉模型中有最高的 shape bias 记录:而大多数模型都有很强的 texture bias (20-30% 的 shape bias + 70-80% 的 texture bias),人类的 shape bias 为 96% + 4%,而 ViT-22B -384 达到了前所未见的 87% shape bias + 13% texture bias。总体而言,ViT-22B显著改善了人类视觉物体识别的对齐。

6233748e034c708c57467436d880786d.jpeg

图13:ViT-22B 与人类感知的一致性

除此之外,作者还对 ViT-22B 的公平性,鲁棒性,可靠性和校准。进行了相关实验,详细细节读者可以参考原始论文。作者发现,随着模型尺寸的增加,会出现有利的特性。

总结

本文提出了 ViT-22B,目前最大的视觉 Transformer 模型,有220亿个参数。作者证明,通过对原始架构进行三点修改,可以实现出色的硬件利用率训练稳定性,从而产生一个在几个基准上实现 SOTA 的模型。作者的评估进一步表明,与现有模型相比,ViT-22B 在形状和纹理偏差方面更符合人类,并在公平性和稳健性方面具有优势。

参考:

基于深度模型Out of Distribution(OOD)基础技术路线研究:https://blog.youkuaiyun.com/Aqrose_666/article/details/124592372

参考

  1. ^PaLI: A Jointly-Scaled Multilingual Language-Image Model

  2. ^Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

  3. ^Unifying language learning paradigms

  4. ^Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

  5. ^Scaling Vision with Sparse Mixture of Experts

  6. ^PaLM: Scaling language modeling with pathways

  7. ^Intriguing Properties of Transformer Training Instabilities

  8. ^Root Mean Square Layer Normalization

国内首个自动驾驶学习社区

近1000人的交流社区,和20+自动驾驶技术栈学习路线,想要了解更多自动驾驶感知(分类、检测、分割、关键点、车道线、3D目标检测、多传感器融合、目标跟踪、光流估计、轨迹预测)、自动驾驶定位建图(SLAM、高精地图)、自动驾驶规划控制、领域技术方案、AI模型部署落地实战、行业动态、岗位发布,欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频,期待交流!

cdff8971950d593cb4d45408453bafdf.jpeg

自动驾驶之心】全栈技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦目标检测、语义分割、全景分割、实例分割、关键点检测、车道线、目标跟踪、3D目标检测、BEV感知、多传感器融合、SLAM、光流估计、深度估计、轨迹预测、高精地图、NeRF、规划控制、模型部署落地、自动驾驶仿真测试、硬件配置、AI求职交流等方向;

c3e7580b860e8dd7a538c2e02c28accc.jpeg

添加汽车人助理微信邀请入群

备注:学校/公司+方向+昵称

给定引用中未提及ViT模型参数情况。一般而言,ViT模型参数主要涉及以下几个方面: ### 图像块嵌入参数 - **图像块大小(Patch Size)**:将输入图像分割成固定大小的图像块,例如常见的16x16、32x32等。这个参数决定了图像被分割的粒度,较小的图像块可以捕捉更多的细节信息,但会增加计算量和参数数量。 - **嵌入维度(Embedding Dimension)**:每个图像块被映射到一个固定维度的向量,这个维度就是嵌入维度。例如,在ViT-Base模型中,嵌入维度通常为768。较大的嵌入维度可以提供更丰富的特征表示,但同样会增加模型的复杂度。 ### 位置编码参数 - **位置编码方式**:为了让模型能够感知图像块的位置信息,需要对每个图像块的嵌入向量添加位置编码。常见的位置编码方式有可学习的位置编码和固定的正弦位置编码。 - **位置编码维度**:位置编码的维度通常与嵌入维度相同,以确保可以直接与图像块嵌入向量相加。 ### 多头自注意力机制参数 - **头的数量(Number of Heads)**:多头自注意力机制通过多个头并行地计算注意力,每个头关注不同的特征子空间。例如,ViT-Base模型通常使用12个头。头的数量越多,模型可以捕捉到的特征信息越丰富,但也会增加计算量。 - **键、查询、值矩阵维度**:在自注意力计算中,需要将输入的嵌入向量分别映射到键(Key)、查询(Query)和值(Value)矩阵。这些矩阵的维度通常与嵌入维度相关。 ### 多层感知机(MLP)参数 - **隐藏层维度**:MLP通常包含一个或多个隐藏层,隐藏层的维度决定了模型的非线性表达能力。例如,在ViT-Base模型中,MLP的隐藏层维度通常为嵌入维度的4倍。 - **激活函数**:MLP中常用的激活函数有GELU(Gaussian Error Linear Unit)等。 ### 分类头参数 - **输出维度**:分类头的输出维度取决于具体的任务,例如在图像分类任务中,输出维度等于类别数量。 以下是一个简单的PyTorch代码示例,展示了如何定义一个基本的ViT模型结构: ```python import torch import torch.nn as nn class ViT(nn.Module): def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim): super().__init__() assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' num_patches = (image_size // patch_size) ** 2 patch_dim = 3 * patch_size ** 2 self.patch_embedding = nn.Linear(patch_dim, dim) self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim), num_layers=depth ) self.mlp_head = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, num_classes) ) def forward(self, x): patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size) patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(x.shape[0], -1, 3 * patch_size ** 2) x = self.patch_embedding(patches) cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding x = self.transformer(x) x = x[:, 0] x = self.mlp_head(x) return x # 示例使用 image_size = 224 patch_size = 16 num_classes = 10 dim = 768 depth = 12 heads = 12 mlp_dim = 3072 model = ViT(image_size, patch_size, num_classes, dim, depth, heads, mlp_dim) input_tensor = torch.randn(1, 3, image_size, image_size) output = model(input_tensor) print(output.shape) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值