Shunted Self-Attention via Multi-Scale Token Aggregation | CVPR 2022

This post looks at OliverRensu's Shunted Transformer (CVPR 2022), a vision transformer built around shunted self-attention with multi-scale token aggregation, so that different attention heads can attend to features at different granularities. It also walks through a simplified PyTorch implementation of the attention module.


Code: https://github.com/OliverRensu/Shunted-Transformer
### Python Implementation and Analysis of the Shunted Transformer

#### Python Code Implementation

The Shunted Transformer is an improved vision transformer architecture that performs well on computer vision tasks. Below is a simplified version of the core of the Shunted Self-Attention mechanism, which aggregates features across different scales.

```python
import torch
from torch import nn


class ShuntedSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads."
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.num_heads = num_heads

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_proj = nn.Conv2d(dim, dim * 2, kernel_size=(1, 1), stride=(1, 1))
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W

        # Query projection: (B, H, W, C) -> (B, num_heads, N, C // num_heads).
        q = self.q_proj(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # Key and value projections via a 1x1 convolution, reshaped into multi-head format.
        kv = self.kv_proj(x.permute(0, 3, 1, 2)).reshape(B, 2 * C, N).transpose(-2, -1)
        k, v = kv.chunk(2, dim=-1)
        k = k.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # Attention scores between queries and keys across all heads.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values according to the computed attention.
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)

        # Final linear layer followed by dropout, as in the original paper's design.
        out = self.proj(out)
        out = self.proj_drop(out)
        return out.reshape(B, H, W, C)


# Example usage:
model = ShuntedSelfAttention(dim=768, num_heads=12)
input_tensor = torch.randn((1, 14, 14, 768))  # batch size 1, 14x14 patch grid, embedding dim 768
output = model(input_tensor)
print(output.shape)  # torch.Size([1, 14, 14, 768])
```

This code defines the `ShuntedSelfAttention` class, a simplified stand-in for one of the paper's key components, multi-scale token aggregation[^4]. This design lets the model improve performance while keeping the parameter count relatively low.

As for MATLAB simulation: the Shunted Transformer is developed mainly on the PyTorch framework and optimized for deep learning applications, so no official MATLAB version is provided. Interested researchers can, however, port the logic of the Python code above to a MATLAB environment to implement similar functionality[^2].
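The snippet above keeps keys and values at full token resolution, so it does not show the part that gives shunted attention its name: different groups of heads attending to keys and values aggregated at different spatial scales. Below is a minimal sketch of that idea, assuming two scale groups with illustrative reduction ratios of 4 and 8. The module name `MultiScaleKVAttention`, the two-group split, and the exact layer choices are simplifications for illustration, not a copy of the official repository linked above, which adds further details such as a local-enhancement branch on V and stage-dependent ratios.

```python
import torch
from torch import nn


class MultiScaleKVAttention(nn.Module):
    """Sketch of shunted attention with two key/value scales (illustrative, not the official code).

    Half of the heads attend to keys/values downsampled by sr_ratios[0], the other half
    to keys/values downsampled by sr_ratios[1], so different heads see tokens aggregated
    at different granularities within the same attention layer.
    """

    def __init__(self, dim, num_heads=8, sr_ratios=(4, 8)):
        super().__init__()
        assert dim % num_heads == 0 and num_heads % 2 == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.sr_ratios = sr_ratios

        self.q = nn.Linear(dim, dim)
        # One strided conv per scale: kernel_size == stride == r merges r*r tokens into one.
        self.sr = nn.ModuleList(
            [nn.Conv2d(dim, dim, kernel_size=r, stride=r) for r in sr_ratios]
        )
        self.norm = nn.ModuleList([nn.LayerNorm(dim) for _ in sr_ratios])
        # Each group's projection produces k and v for half of the heads.
        self.kv = nn.ModuleList([nn.Linear(dim, dim) for _ in sr_ratios])
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        h = self.num_heads // 2  # heads per scale group

        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q_groups = q.chunk(2, dim=1)  # split the heads into two scale groups

        outs = []
        x_img = x.permute(0, 3, 1, 2)  # (B, C, H, W) for the strided convs
        for i, r in enumerate(self.sr_ratios):
            # Aggregate r*r neighbouring tokens into a single key/value token.
            x_r = self.sr[i](x_img).reshape(B, C, -1).transpose(1, 2)  # (B, N / r^2, C)
            x_r = self.norm[i](x_r)
            kv = self.kv[i](x_r).reshape(B, -1, 2, h, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]  # each (B, h, N / r^2, head_dim)

            attn = (q_groups[i] @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            outs.append(attn @ v)  # (B, h, N, head_dim)

        out = torch.cat(outs, dim=1).transpose(1, 2).reshape(B, N, C)
        return self.proj(out).reshape(B, H, W, C)


# Usage: a 56x56 token grid so both reduction ratios divide the spatial size evenly.
attn = MultiScaleKVAttention(dim=64, num_heads=2, sr_ratios=(4, 8))
y = attn(torch.randn(1, 56, 56, 64))
print(y.shape)  # torch.Size([1, 56, 56, 64])
```

Splitting the heads into groups with different reduction ratios is what lets a single attention layer mix coarse context (strongly downsampled keys/values) with fine detail (lightly downsampled ones); this is the multi-scale token aggregation referred to above.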