突破视觉特征提取瓶颈:VGGT中PatchEmbed与Position Encoding的创新实践
【免费下载链接】vggt VGGT Visual Geometry Grounded Transformer 项目地址: https://gitcode.com/gh_mirrors/vg/vggt
你是否还在为图像特征提取的效率与精度难以兼顾而困扰?是否好奇先进的视觉Transformer如何将复杂图像转化为计算机可理解的语言?本文将深入解析VGGT(Visual Geometry Grounded Transformer)中的两大核心模块——PatchEmbed与Position Encoding,带你掌握视觉特征提取的关键技术,看完即能理解图像到向量的转换奥秘。
视觉特征提取的核心挑战
在计算机视觉领域,将二维图像转化为一维特征向量是所有视觉任务的基础。传统卷积神经网络(CNN)通过滑动窗口提取局部特征,但在处理长距离依赖关系时存在固有局限。而视觉Transformer(ViT)的出现,通过将图像分割为补丁(Patch)并添加位置编码(Position Encoding),成功实现了全局特征建模。
VGGT作为新一代视觉几何Transformer,其特征提取模块在继承ViT优势的基础上进行了针对性优化。项目核心代码结构显示,特征提取相关实现集中在vggt/layers/目录下,其中patch_embed.py和rope.py分别实现了补丁嵌入和位置编码功能,并在vision_transformer.py中完成整合。
PatchEmbed:图像到补丁的高效转化
PatchEmbed模块负责将输入图像分割为固定大小的补丁,并将每个补丁线性投影为特征向量。这一过程类似于将一幅画切割成若干小块,每块用一个向量来描述。
核心原理与实现
VGGT的PatchEmbed实现位于vggt/layers/patch_embed.py,其核心代码如下:
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
上述代码通过以下步骤完成图像到补丁的转化:
- 图像尺寸处理:将输入图像尺寸和补丁尺寸转换为元组形式,确保兼容性
- 网格计算:计算补丁网格大小和总补丁数量
- 卷积投影:使用卷积层(kernel_size和stride均等于patch_size)将每个补丁投影到嵌入维度
- 归一化:可选的归一化层进一步优化特征分布
前向传播过程
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
前向传播过程中,输入图像(B, C, H, W)经过卷积层得到(B, embed_dim, H/patch_size, W/patch_size)的特征图,随后展平为(B, num_patches, embed_dim)的补丁序列。
实际效果展示
以项目示例图像为例,原始输入图像如examples/llff_fern/images/000.png所示:
经过PatchEmbed处理后,图像被分割为14×14(默认224×224输入,16×16补丁)的补丁网格,每个补丁转化为768维的特征向量,形成196×768的特征矩阵。
Position Encoding:空间位置信息的有效融入
仅仅将图像分割为补丁并投影为向量是不够的,因为Transformer本身不具备感知位置信息的能力。位置编码(Position Encoding)通过为每个补丁添加位置相关的信息,使模型能够理解补丁在原始图像中的空间位置关系。
2D Rotary Position Embedding的创新
VGGT采用了创新的2D Rotary Position Embedding(RoPE),不同于传统的正弦余弦位置编码,RoPE通过旋转操作将位置信息融入特征向量,具有更好的平移不变性和长序列建模能力。相关实现位于vggt/layers/rope.py。
位置生成器
class PositionGetter:
"""Generates and caches 2D spatial positions for patches in a grid."""
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
if (height, width) not in self.position_cache:
y_coords = torch.arange(height, device=device)
x_coords = torch.arange(width, device=device)
positions = torch.cartesian_prod(y_coords, x_coords)
self.position_cache[height, width] = positions
cached_positions = self.position_cache[height, width]
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
PositionGetter类负责生成并缓存补丁的2D空间坐标,避免重复计算,提高效率。
2D RoPE核心实现
class RotaryPositionEmbedding2D(nn.Module):
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
# Validate inputs
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
# Compute feature dimension for each spatial direction
feature_dim = tokens.size(-1) // 2
# Get frequency components
max_position = int(positions.max()) + 1
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
# Split features for vertical and horizontal processing
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
# Apply RoPE separately for each dimension
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
# Combine processed features
return torch.cat((vertical_features, horizontal_features), dim=-1)
2D RoPE将特征向量分为垂直和水平两部分,分别对y坐标和x坐标应用旋转操作,最后合并结果。这种处理方式能够更好地捕捉2D图像的空间几何关系。
位置编码效果对比
传统正弦余弦位置编码直接将位置信息加在特征向量上,而RoPE通过旋转操作融合位置信息,在处理大尺寸图像和长序列时表现更优。以下是两种编码方式在不同任务上的对比:
| 编码方式 | 计算复杂度 | 平移不变性 | 长序列性能 | 几何感知能力 |
|---|---|---|---|---|
| 正弦余弦编码 | 低 | 弱 | 一般 | 弱 |
| 2D RoPE | 中 | 强 | 好 | 强 |
VGGT采用的2D RoPE在保持计算效率的同时,显著提升了模型对图像几何结构的理解能力,这对于后续的视觉几何任务至关重要。
模块整合与应用
在VGGT中,PatchEmbed和Position Encoding并非孤立存在,而是在视觉Transformer的整体架构中协同工作。vision_transformer.py中的DinoVisionTransformer类完成了这些模块的整合:
class DinoVisionTransformer(nn.Module):
def __init__(...):
# 初始化PatchEmbed
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size,
in_chans=in_chans, embed_dim=embed_dim)
# 初始化位置编码相关参数
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x) # 应用PatchEmbed
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h) # 添加位置编码
return x
上述代码展示了VGGT特征提取的完整流程:首先通过PatchEmbed将图像转化为补丁特征序列,然后添加位置编码,最终形成同时包含内容信息和位置信息的特征序列,输入后续Transformer编码器。
实际应用与效果展示
VGGT的特征提取模块在多种视觉任务中表现出色。以场景重建任务为例,输入一系列不同角度的图像,如examples/kitchen/images/目录下的厨房场景图像序列:
经过VGGT的特征提取和后续处理,能够重建出精确的3D场景结构。这一过程中,PatchEmbed对图像细节的精确捕捉和2D RoPE对空间位置的准确建模起到了关键作用。
总结与展望
VGGT的PatchEmbed和Position Encoding模块通过创新设计,有效解决了视觉Transformer在特征提取阶段的核心挑战。PatchEmbed实现了图像到特征向量的高效转化,而2D RoPE则为模型提供了强大的空间位置感知能力。两者的协同工作,为后续的视觉几何任务奠定了坚实基础。
随着计算机视觉技术的发展,特征提取模块将朝着更高效、更鲁棒的方向演进。VGGT作为视觉几何Transformer的代表,其特征提取方案为相关研究提供了重要参考。想要深入了解更多细节,可以查阅项目完整代码,特别是vggt/layers/目录下的相关实现。
掌握VGGT的特征提取技术,将帮助你在计算机视觉任务中取得更好的效果。立即动手实践,体验视觉Transformer的强大能力!
【免费下载链接】vggt VGGT Visual Geometry Grounded Transformer 项目地址: https://gitcode.com/gh_mirrors/vg/vggt
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考





