文章目录
摘要
本周学习了Vision Transformer (ViT) 的基本原理及其实现,并完成了基于PyTorch的模型训练、验证和预测任务。深入理解了ViT如何将图像分割成patch作为输入序列,并结合Transformer Encoder处理。通过迁移学习在花类数据集上训练模型,并验证了模型在预测任务中的优越性能。
Abstract
This week, I studied the fundamental principles and implementation of Vision Transformer (ViT) and completed model training, validation, and prediction tasks using PyTorch. I gained a deep understanding of how ViT splits an image into patches as input sequences and processes them using the Transformer Encoder. By leveraging transfer learning, I trained the model on a flower dataset and validated its superior performance in prediction tasks.
Vision Transformer
1 原理
- 数据处理
我认为ViT的关键在于理解怎么将图片当作一个序列输入进模型之中。我们先看看ViT整体结构图,如下图所示
论文中提到将 224x224x3 的图像作为输入,将图像分为 16x16x3 大小的patch,也就是说将输入图像分为了 224 × 224 × 3 16 × 16 × 3 = 196 \frac{224×224×3}{16×16×3}=196 16×16×3224×224×3=196 个patch。其中每个patch拉直之后的维度为 16×16×3=768维,也就是Linear Projection of Flattened Patches层下面分割的小图像。
在具体实现中,使用卷积核大小为 16x16x3 、步距为16、卷积核个数为768的卷积层,就能将3维图像转换为Transformer所需要的输入token[组数,维度]。
-
全连接层
上述[196,768]的token将传入Linear Projection of Flattened Patches层,该层是 768x768 的全连接层,该层输出认为 196x768 。 -
位置编码
将经过全连接层后的输出进行位置编码,其位置编码和Transformer中的时序编码有异曲同工之妙,前者可以通过位置编码表示出token之间关于原输入图像的一些位置信息,后者可以表示输入先后的时序信息。
该模型位置编码通过类似于坐标的形式表达,直接于输入相加,不改变维度大小。如下图所示:
进行位置编码后,还需要加上一个特殊字符(最左输入0*),输入总组数从之前的196变为197,传入Transformer Encoder的token为[197,768]。
- Transformer Encoder
ViT采用的是Transformer中编码器进行叠加,但其中的参数数量有所不同。
经过位置编码和加入特殊字符的token[197,768]传入编码器,首先经过层归一化,再经过多头自注意力。这里的多头自注意力是采用12个头,也就是将768维分为12份,每份(Q、K、V)64维度,计算之后再进行合并为768维。
ViT中的编码器仍是采用残差连接,再经过一次层归一化后,就进入单个Transformer Encoder的最后一层MLP(多层感知机)。MLP将经过多头自注意力的输出维度升高4倍,即从768变为3072,最后再将维度降至768维
ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。
- 输出
ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。
最后,通过全连接层和softmax进行概率输出即可
2 代码
在理解完ViT的原理之后,我们来看看PyTorch代码如何实现。这里以ViT-base模型,输入图像 224x224x3,patch大小 16x16x3 为例
花类数据集:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzy
训练模型代码如下,需要自行更改数据集路径和权重路径。
import os
import math
import argparse
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
if os.path.exists("../weights") is False:
os.makedirs("../weights")
tb_writer = SummaryWriter()
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(