Transformer——Q104 视觉Transformer中Patch Embedding的参数量计算(图像尺寸 H×W,Patch大小 P)

该问题归类到Transformer架构问题集——架构变体——跨模态扩展。请参考LLM数学推导——Transformer架构问题集

1. 背景知识:视觉 Transformer 与 Patch Embedding 的诞生逻辑

视觉 Transformer(ViT)是 Transformer 架构在计算机视觉领域的延伸。传统卷积神经网络(CNN)通过局部卷积操作提取图像特征,而 Transformer 的自注意力机制能捕捉全局依赖关系。但直接将 Transformer 应用于图像面临严峻挑战:一幅 224 \times 224 \times 3 的图像包含 150,528 个像素,若将每个像素视为一个 “词”,序列长度极长,计算量呈指数级增长,超出硬件处理能力。

为解决这一问题,ViT 提出将图像分割为多个固定大小的 Patch(块)。每个 Patch 作为一个 “视觉词”,大幅缩短序列长度。例如,将 224 \times 224 的图像分割为 16 \times 16 的 Patch,仅需14 \times 14 = 196 个 Patch,使 Transformer 能够高效处理图像数据。Patch Embedding 的核心任务,就是将这些 Patch 转化为适合 Transformer 处理的嵌入向量,而参数量计算是理解该模块复杂度与设计合理性的关键。

2. 技术原理:从图像分块到参数量的数学推导

2.1 图像分块与特征拉平

假设输入图像尺寸为 H \times W,通道数为 C(如 RGB 图像 C = 3),Patch 大小为 P \times P。图像被均匀划分为 \frac{H}{P} \times \frac{W}{P} 个 Patch(要求 H

### 原理 在时空Transformer中,Patch Embedding将二维图像矩阵转换为一维向量序列。其核心在于把图像分割成多个不重叠的小块(Patch),然后将每个Patch通过线性投影转换为固定长度的向量。这样做是因为Transformer模型通常处理的是序列数据,而图像本身是二维结构,通过Patch Embedding可以将图像数据转换为适合Transformer处理的序列形式,同时在转换过程中保留图像的空间结构信息[^1]。 ### 作用 - **维度转换**:将二维图像数据转换为一维向量序列,使得图像数据能够作为输入被时空Transformer模型处理。 - **信息保留**:在转换过程中保留图像的空间结构信息,有助于后续模型捕捉图像中的局部和全局特征。 - **计算效率**:通过合理的分块策略,平衡模型性能和计算成本,避免直接处理高分辨率图像带来的巨大计算量。 ### 实现方式 以下是一个简单的Python代码示例,使用PyTorch实现基本的Patch Embedding: ```python import torch import torch.nn as nn class PatchEmbedding(nn.Module): def __init__(self, image_size, patch_size, in_channels, embed_dim): super(PatchEmbedding, self).__init__() self.image_size = image_size self.patch_size = patch_size self.num_patches = (image_size // patch_size) * (image_size // patch_size) self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) # [B, C, H, W] -> [B, embed_dim, H', W'] x = x.flatten(2) # [B, embed_dim, H', W'] -> [B, embed_dim, num_patches] x = x.transpose(1, 2) # [B, embed_dim, num_patches] -> [B, num_patches, embed_dim] return x # 示例使用 image_size = 224 patch_size = 16 in_channels = 3 embed_dim = 768 patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim) x = torch.randn(1, in_channels, image_size, image_size) output = patch_embed(x) print(output.shape) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

墨顿

唵嘛呢叭咪吽

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值