【Transform 2】ViT 论文学习笔记+训练微调

一、背景

        前面学习了 Transformer,本文再来学习 ViT,阅读论文的基础上结合代码实践。ViT 刚发布的时候本人还在读研,ViT 一出来 CV 领域也开始频繁出现 Transformer 的身影,无奈当时已经确定了研究方向和技术路线,就没有再研究 ViT。一晃几年过去,现在才算是开始学习 ViT,虽然可能用到的机会不多,但学习一下论文思想和模型结构相信还是很有用的。

二、论文

Paper:[2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

1. 模型结构

        核心思想:将图像切成一个个小块(patch),经过线性投影转换为 token 序列,每个 token 拼接一个位置编码表示 patch 之间的位置关系,[class] token 标识图像类别,然后将这个序列送入 Transformer 的 Encoder,MLP 中的分类头负责预测类别信息。

        下面结合代码分析处理过程,代码使用的 pytorch 版本的 ViT:https://github.com/kentaroy47/vision-transformers-cifar10

        ViT 主体实现:

class ViT(nn.Module):
    """
    image_size: 32*32*3
    patch_size: 4
    dim: 各层输出维度,Transformer中的d_model=512
    depth: Transformer的堆叠数,N
    heads: 多头注意力头数
    pool: 类别预测的方式,cls: 单独添加类别token(论文中的方式,默认);mean: 所有token的特征平均
    dim_head: 每个注意力头的维度,d_model/heads
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        
        # 确保能切分出整数个patch
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        
        # patch个数
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # 每个patch的大小=P^2 * C
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        
        # 图像嵌入层,维度映射为d_model
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        
        # 位置嵌入层
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # class token 层
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # dropout层
        self.dropout = nn.Dropout(emb_dropout)
        # Transformer层
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        
        # pool 类型
        self.pool = pool
        # 恒等映射层,对输入做原样输出,这里貌似没啥作用
        self.to_latent = nn.Identity()
        
        # MLP分类层,先做LN,再将d_model映射为类别数进行预测
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # 图像嵌入 (img: [128, 3, 32, 32] --> x: [128, 64, 512])
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        #为当前批数据添加class token (self.cls_token: [1, 1, 512],cls_token: [128, 1, 512])
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # 在patch序列中嵌入class token (x: [128, 64, 512] --> x: [128, 65, 512])
        x = torch.cat((cls_tokens, x), dim=1)
        # 每个token添加位置信息
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        
        # 取出序列中的第一个token,即class token
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # 返回分类预测
        return self.mlp_head(x)

2. 图像嵌入

        实现如何将图像数据转换为 Transformer 需要的序列数据。

              

2.1 patch embedding

        对于一张输入图像 ,将其切分成 patch 序列,每个 patch ,共  个,由 patch 组成的序列 


                
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值