Hand-Written Series: A Network Based on the CrossAttention Structure

References:

CrossFormer in Practice: Implementing Image Classification with CrossFormer (Part 1), CSDN Blog

TPAMI 2024 | Sheer Genius! Zhejiang University Proposes a Cross-Scale, Long-Range Attention Transformer That Handles Multiple Vision Tasks

TPAMI 2024 | CrossFormer++: A Versatile Vision Transformer Based on Cross-Scale Attention

Code (a compact CIFAR-10 classifier that inserts cross-scale attention blocks between the convolutional stages):

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Cross-scale attention module (refs. [4][5]): a single-head, convolutional
# Q/K/V attention block applied to a feature map
class CrossScaleAttention(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convs project the input into query/key/value spaces; query/key
        # are reduced to in_dim//2 channels to shrink the attention matmul
        self.query_conv = nn.Conv2d(in_dim, in_dim // 2, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 2, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.softmax = nn.Softmax(dim=-1)
        # learnable residual scale, initialized to 0 so the block starts as identity
        self.gamma = nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        batch_size, C, H, W = x.size()

        # Q/K/V projections, flattened over the N = H*W spatial positions
        # (loosely following the CEL idea of ref. [5])
        query = self.query_conv(x).view(batch_size, -1, H*W).permute(0, 2, 1)  # [B, N, C']
        key = self.key_conv(x).view(batch_size, -1, H*W)                       # [B, C', N]
        value = self.value_conv(x).view(batch_size, -1, H*W)                   # [B, C, N]

        # position-to-position attention (refs. [1][7]); note the dot product
        # is left unscaled (no 1/sqrt(d) factor)
        energy = torch.bmm(query, key)  # [B, N, N]
        attention = self.softmax(energy)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)

        return self.gamma * out + x  # residual connection

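# Quick shape check (an illustrative sketch; the 16x16 feature-map size is an
# arbitrary assumption): the block is residual, so output dims must match input
_x = torch.randn(2, 64, 16, 16)  # [B, C, H, W]
assert CrossScaleAttention(64)(_x).shape == _x.shape
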
# Full classification model (refs. [4][5][7])
class CrossAttentionClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: two conv stages, each followed by cross-scale attention
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            CrossScaleAttention(64),
            
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            CrossScaleAttention(128)
        )
        
        # Classification head (feature-fusion idea of ref. [2])
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)
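
# Forward-pass check (an illustrative sketch): with CIFAR-10-sized inputs
# [B, 3, 32, 32], the classifier should emit one logit per class
_logits = CrossAttentionClassifier()(torch.randn(4, 3, 32, 32))
assert _logits.shape == (4, 10)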

# Training configuration: standard CIFAR-10 augmentation and normalization
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

# Initialize the model, loss, and optimizer; train on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CrossAttentionClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Training loop (training tricks from ref. [4])
for epoch in range(50):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/50], Loss: {running_loss / len(train_loader):.4f}')
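
The script above only trains; a minimal evaluation sketch follows (it reuses model and device from above, and assumes a flip-free transform is the right choice for the test split, since the random flip is augmentation meant only for training):

# Evaluation sketch: top-1 accuracy on the CIFAR-10 test split
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_set = CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_set, batch_size=256)

model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'Test accuracy: {correct / total:.4f}')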

Training results: (screenshot omitted)
