A Record of Tensor Shapes in Attention Computation

Preface

Anyone studying computer vision will sooner or later run into the attention mechanism. When I was starting out, the Q/K/V computation always felt like a fog to me, so I took this opportunity to annotate the tensor shape produced at every step of the computation.

Beginners should first have some basic knowledge of matrix multiplication.
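As a warm-up, batched matrix multiplication in PyTorch treats the last two dimensions as matrices and broadcasts the leading ones. The sizes below are arbitrary examples of mine, chosen to mirror the per-head multiplications that appear later:

```python
import torch

# Batched matmul: leading dims broadcast, the last two dims multiply as matrices.
# (B, heads, M, K) @ (B, heads, K, N) -> (B, heads, M, N)
a = torch.randn(2, 4, 16, 1024)   # e.g. per-head Q: [B, heads, C/heads, H*W]
b = torch.randn(2, 4, 1024, 16)   # e.g. transposed K: [B, heads, H*W, C/heads]
c = a @ b
print(c.shape)  # torch.Size([2, 4, 16, 16])
```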


Code Analysis

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable temperature parameter, shape [num_heads, 1, 1]
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, dtype=torch.float32), requires_grad=True)

        # QKV generation layer (1x1 conv: spatial size unchanged, only the channel count changes)
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        # Depthwise convolution (groups = channels, so each group processes 1 channel)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        # Output projection layer
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape  # input shape [B, C, H, W]

        # Generate the Q/K/V triplet
        qkv = self.qkv(x)  # [B, 3C, H, W] (channel count tripled)
        qkv = self.qkv_dwconv(qkv)  # [B, 3C, H, W] (spatial size unchanged)

        # Split into Q / K / V
        q, k, v = qkv.chunk(3, dim=1)  # each of shape [B, C, H, W]

        # Reshape into multi-head format
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # q now has shape [B, num_heads, C/num_heads, H*W]

        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # k shape [B, num_heads, C/num_heads, H*W]

        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # v shape [B, num_heads, C/num_heads, H*W]

        # L2-normalize along the last (spatial) dimension
        q = F.normalize(q, dim=-1)  # shape stays [B, num_heads, C/num_heads, H*W]
        k = F.normalize(k, dim=-1)  # shape stays [B, num_heads, C/num_heads, H*W]

        # Compute the attention matrix: attention is over channels, not spatial positions
        # [B, num_heads, C/num_heads, H*W] @ [B, num_heads, H*W, C/num_heads]
        attn = (q @ k.transpose(-2, -1)) * self.temperature  # [B, num_heads, C/num_heads, C/num_heads]
        attn = attn.softmax(dim=-1)  # shape unchanged [B, num_heads, C/num_heads, C/num_heads]

        # Attention-weighted aggregation
        # [B, num_heads, C/num_heads, C/num_heads] @ [B, num_heads, C/num_heads, H*W]
        out = attn @ v  # result shape [B, num_heads, C/num_heads, H*W]

        # Reshape back to the original format
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        # output shape [B, C, H, W]

        # Final projection
        out = self.project_out(out)  # shape stays [B, C, H, W]
        return out
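To double-check the shape bookkeeping above, here is a minimal standalone sketch of the same per-head computation. The concrete sizes are my own choice, and `torch.reshape` stands in for the `einops.rearrange` call (the flattening order is the same):

```python
import torch
import torch.nn.functional as F

b, c, h, w, heads = 2, 64, 32, 32, 4          # C must be divisible by num_heads
c_head, hw = c // heads, h * w

# Per-head Q/K/V, as produced by the rearrange step: [B, heads, C/heads, H*W]
q = F.normalize(torch.randn(b, heads, c_head, hw), dim=-1)
k = F.normalize(torch.randn(b, heads, c_head, hw), dim=-1)
v = torch.randn(b, heads, c_head, hw)

attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)  # [B, heads, C/heads, C/heads]
out = attn @ v                                    # [B, heads, C/heads, H*W]
out = out.reshape(b, c, h, w)                     # back to [B, C, H, W]
print(attn.shape, out.shape)
```

Note that the attention map is a small square matrix over channels (C/heads by C/heads), not over spatial positions; this is what keeps the cost linear in H*W.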
