References:
CrossFormer in Practice: Implementing an Image Classification Task with CrossFormer (Part 1) (CSDN blog)
TPAMI 2024 | Zhejiang University proposes a cross-scale, long-range attention Transformer capable of handling multiple vision tasks
TPAMI 2024 | CrossFormer++: A Versatile Vision Transformer Based on Cross-Scale Attention
Code:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# Cross-scale attention module (refs. [4][5])
class CrossScaleAttention(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convolutions project the input into query/key/value spaces
        self.query_conv = nn.Conv2d(in_dim, in_dim // 2, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 2, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.softmax = nn.Softmax(dim=-1)
        # learnable gate, initialized to 0 so the block starts as identity
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Flatten spatial positions into tokens (CEL idea from ref. [5])
        query = self.query_conv(x).view(batch_size, -1, H * W).permute(0, 2, 1)  # [B, N, C']
        key = self.key_conv(x).view(batch_size, -1, H * W)                       # [B, C', N]
        value = self.value_conv(x).view(batch_size, -1, H * W)                   # [B, C, N]
        # Attention over all pairs of spatial positions (refs. [1][7])
        energy = torch.bmm(query, key)  # [B, N, N]
        attention = self.softmax(energy)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x  # residual connection
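
# Quick shape check (illustrative addition, not from the original post):
# the block is residual, so the output shape must match the input shape.
_x = torch.randn(2, 64, 16, 16)
assert CrossScaleAttention(64)(_x).shape == _x.shape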
# Full classification model (refs. [4][5][7])
class CrossAttentionClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: two conv stages, each followed by attention
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            CrossScaleAttention(64),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            CrossScaleAttention(128)
        )
        # Classification head (feature-fusion idea from ref. [2])
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)
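
# Smoke test (illustrative addition): CIFAR-10 inputs are 3x32x32, so a
# forward pass should produce logits of shape [batch, num_classes].
assert CrossAttentionClassifier()(torch.randn(2, 3, 32, 32)).shape == (2, 10)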
# Training configuration
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

# Initialize the model (run on GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CrossAttentionClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Training loop (training tricks from ref. [4])
for epoch in range(50):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
    print(f'Epoch [{epoch+1}/50], Loss: {loss.item():.4f}')
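
To quantify the result, here is a minimal evaluation sketch (an illustrative addition, not from the original post): it loads the CIFAR10 test split with a flip-free transform and reports top-1 accuracy.

# Evaluation on the CIFAR-10 test split (no flip augmentation at test time)
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_set = CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)
model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'Test accuracy: {100.0 * correct / total:.2f}%')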
Training results: