Transformer中,MHSA(multi-head self-attention)中为什么要分多个Head?

本文解析了多头注意力机制如何增强模型的表达学习能力。通过将输入数据拆分为多个Head,使得每个Head可以独立进行Attention计算,从而学习到不同的特征。这种机制不仅提高了模型的学习能力,而且计算复杂度保持不变。
AI助手已提取文章相关产品:

结论:模型的表达学习能力增强了

输入到MHSA中的数据的shape应该为B × L × Embedding,B是Batch,L是序列长度
而在MHSA中,数据的shape会被拆分为多个Head,所以shape会进一步变为:
B × L × Head × Little_Embedding

以Transformer为例,原始论文中Embedding为512,Head数为8,所以shape在进入MHSA中时,会变为:
B × L × 8 × 64

如果不分头,相当于对512*512的矩阵进行Attention计算;
而如果分头了,相当于8个头中,每个头彼此独立进行Attention计算,不同头学习到的特征也可能是不同的,相当于增强了模型的表达学习能力。
并且,8 次的 64 × 64,和 1 次的 512 × 512,两者的计算复杂度是一致的,并没有造成额外的计算开销。

您可能感兴趣的与本文相关内容

<think> 我们要求将BroadcastAttention替换为Windowing Multi-Head Attention。Windowing Multi-Head Attention通常是指将输入割为多个窗口(或块),在每个窗口内进行多头注意力计算,以减少计算复杂度并捕捉局部依赖。 注意:原BroadcastAttention设计用于一维时序信号(Conv1d),因此Windowing Multi-Head Attention也需要适应一维输入。 步骤: 1. 我们将实现一个名为WindowMultiHeadAttention的模块,它应该能够处理形状为[B, C, L]的输入(其中L是序列长度)。 2. 由于是多头注意力,我们首先需要将输入通过线性变换拆多个头。 3. 然后,将序列划多个窗口(非重叠或重叠),在每个窗口内计算自注意力。 4. 最后,合并所有窗口的输出并经过另一个线性变换。 设计要点: - 窗口大小(window_size)和步长(stride)可以设定,通常窗口大小固定,步长等于窗口大小(非重叠)或小于(重叠)。 - 为了简化,我们使用非重叠窗口。如果序列长度不能整除窗口大小,可以填充或截断,这里选择填充(padding)至整数倍。 由于原模型处理的是时序信号,且经过多次下采样,序列长度会逐渐减小,因此窗口大小可以设为固定值或根据序列长度自适应。这里我们设为可配置参数。 另外,多头注意力的头数(num_heads)也需要设定。 我们按照标准的Multi-Head Self-Attention(MHSA)实现,但限制在每个窗口内。 注意:原始BroadcastAttention输出维度与输入相同,我们的新模块也应保持相同。 实现方案: 1. 输入x: [B, C, L] -> 首先通过LayerNorm(因为在BA_FFN_Block中,输入注意力前有LayerNorm,所以这里可以不用,但为了与标准Transformer一致,我们可以在模块内部做或不做?根据ViT等模型,通常先做LayerNorm,但原BA_FFN_Block已经有一个LayerNorm。因此,我们在WindowMultiHeadAttention内部不包含LayerNorm,仅做线性变换和窗口划)。 2. 通过三个线性层(或一个线性层再拆)得到Q、K、V,然后割成多头。 3. 将序列划为窗口,在每个窗口内计算注意力。 4. 合并窗口,然后通过输出线性层。 由于输入是[B, C, L],而标准注意力在序列维度(L)上计算,因此我们需要将通道C作为特征维度。这里,我们设定每个头的维度为head_dim = C // num_heads(要求可整除)。 具体步骤: a) 对输入x(经过LayerNorm后的输入,但LayerNorm在BA_FFN_Block中已经应用)进行线性变换得到Q、K、V。 b) 将Q、K、V按头拆:reshape -> [B, num_heads, L, head_dim] c) 将序列割成窗口,假设窗口大小为w,那么序列会被成n_w = ceil(L/w)个窗口(可能需要填充)。 - 将Q、K、V reshape为 [B, num_heads, n_w, w, head_dim] d) 在每个窗口内计算注意力:attention = softmax(Q @ K.transpose(-1,-2) / sqrt(head_dim)) @ V -> 输出 [B, num_heads, n_w, w, head_dim] e) 合并窗口:reshape -> [B, num_heads, n_w*w, head_dim] -> 然后去掉填充部(如果有)得到 [B, num_heads, L, head_dim] f) 合并多头:reshape -> [B, L, C] (注意:这里我们暂时将通道维放在最后,因为线性层通常作用在最后一维) g) 通过输出线性层,并转置回 [B, C, L](因为后续处理需要通道在中间维度) 注意:由于原模型使用Conv1d,特征维度是通道C,而序列长度是L。在注意力中,我们通常将序列长度视为token数,每个token的维度是C。 但这里有一个问题:原输入是[B, C, L],而标准注意力要求[B, L, C]。因此,在注意力计算前,我们需要将x转置为[B, L, C],计算完后再转置回来。 因此,我们在WindowMultiHeadAttention模块中: - 首先将输入从[B, C, L]转置为[B, L, C] - 然后进行多头注意力的计算 - 最后输出再转置回[B, C, L] 另外,为了与原始BroadcastAttention的输入输出一致,我们还需要考虑投影层和dropout等。 我们设计如下: 参数: - dim: 输入通道数(即C) - num_heads: 多头注意力的头数 - window_size: 窗口大小 - qkv_bias: 线性层是否使用偏置 - attn_drop: 注意力权重的dropout率 - proj_drop: 输出投影的dropout率 步骤: 1. 转置输入: [B, C, L] -> [B, L, C] 2. 通过线性层生成qkv: [B, L, C] -> [B, L, 3*C] (然后拆为q, k, v) 3. reshape q, k, v 为 [B, L, num_heads, head_dim] -> 调整维度为 [B, num_heads, L, head_dim] 4. 将序列窗口:将q, k, v reshape为 [B, num_heads, n_w, window_size, head_dim] 5. 计算窗口内注意力(注意:每个窗口独立计算) 6. 合并窗口:reshape回 [B, num_heads, L, head_dim] 7. 合并多头:reshape为 [B, L, C] 8. 通过输出线性层和dropout 9. 转置输出为 [B, C, L] 注意:如果序列长度不是窗口大小的整数倍,我们需要填充。这里我们在窗口前进行填充,并在合并窗口后移除填充。 为了简化,我们使用torch.nn.functional.pad进行填充。 但注意:原模型输入长度可能变化(如下采样后长度会变短),所以窗口大小可能需要根据实际情况设定(比如设定为固定值,如16或32)。在模型中,我们可以在初始化时根据输入长度来调整?但通常固定窗口大小。 我们实现一个填充函数和一个反填充函数(在内部处理)。 代码实现: 由于代码较长,我们只替换BroadcastAttention为WindowMultiHeadAttention,并保持BA_FFN_Block的其他部不变。 注意:在BA_FFN_Block中,我们使用了LayerNorm(转置操作),所以输入WindowMultiHeadAttention的是[B, C, L](但经过转置归一化后,又转置回来,所以实际上输入WindowMultiHeadAttention的x是[B, C, L])。因此,我们直接处理[B, C, L]并输出[B, C, L]。 下面实现WindowMultiHeadAttention模块: </think> ### 修改后的 Windowing Multi-Head Attention 实现 以下是替换 `BroadcastAttention` 为 Windowing Multi-Head Attention 的完整代码实现: ```python class WindowMultiHeadAttention(nn.Module): def __init__(self, dim, num_heads=4, window_size=8, proj_drop=0., attn_drop=0., qkv_bias=True): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.head_dim = dim // num_heads assert dim % num_heads == 0, "dim must be divisible by num_heads" # QKV projection self.qkv_proj = nn.Conv1d(dim, 3 * dim, kernel_size=1, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=qkv_bias) self.proj_drop = nn.Dropout(proj_drop) def window_partition(self, x): """ 将输入序列划多个窗口 [B, C, L] -> [B * num_windows, num_heads, window_size, head_dim] """ B, C, L = x.shape # 填充序列长度使其可被窗口整除 padding = (self.window_size - L % self.window_size) % self.window_size x = F.pad(x, (0, padding)) L_padded = L + padding # 划窗口 [B, C, L] -> [B * num_windows, C, window_size] x = x.unfold(2, self.window_size, self.window_size) x = x.permute(0, 2, 1, 3).contiguous() x = x.view(-1, C, self.window_size) # 多头割 [B * num_windows, C, window_size] -> [B * num_windows, num_heads, window_size, head_dim] x = x.view(-1, self.num_heads, self.head_dim, self.window_size) x = x.permute(0, 1, 3, 2).contiguous() return x, padding def window_reverse(self, x, L, padding): """ 将窗口合并回完整序列 [B * num_windows, num_heads, window_size, head_dim] -> [B, C, L] """ B_numwindows, num_heads, W, head_dim = x.shape B = B_numwindows // (L // self.window_size) C = num_heads * head_dim # 转换形状 [B * num_windows, num_heads, window_size, head_dim] # -> [B, num_windows, num_heads, window_size, head_dim] x = x.view(B, -1, num_heads, W, head_dim) x = x.permute(0, 2, 4, 1, 3).contiguous() x = x.view(B, C, -1) # 去除填充 if padding > 0: x = x[:, :, :-padding] if padding != self.window_size else x[:, :, :L] return x def forward(self, x): B, C, L_orig = x.shape # 生成QKV [B, C, L] -> [B, 3C, L] qkv = self.qkv_proj(x) q, k, v = torch.split(qkv, self.dim, dim=1) # 窗口划 q, padding = self.window_partition(q) # [B * nW, num_heads, W, head_dim] k, _ = self.window_partition(k) v, _ = self.window_partition(v) # 计算注意力数 attn_scores = torch.matmul(q, k.transpose(-2, -1)) # [B * nW, num_heads, W, W] attn_scores = attn_scores / math.sqrt(self.head_dim) attn_probs = F.softmax(attn_scores, dim=-1) attn_probs = self.attn_drop(attn_probs) # 注意力加权 context = torch.matmul(attn_probs, v) # [B * nW, num_heads, W, head_dim] # 还原窗口 context = self.window_reverse(context, L_orig, padding) # [B, C, L] # 输出投影 out = self.proj(context) out = self.proj_drop(out) return out ``` ### 集成到 BA_FFN_Block 中 ```python class BA_FFN_Block(nn.Module): def __init__(self, dim, ffn_dim, drop=0., attn_drop=0., num_heads=4, # 新增参数 window_size=8 # 新增参数 ): super().__init__() self.norm1 = LayerNorm(dim) self.add1 = Add() # 替换为WindowMultiHeadAttention self.attn = WindowMultiHeadAttention(dim=dim, num_heads=num_heads, window_size=window_size, attn_drop=attn_drop, proj_drop=drop) # 以下部保持不变 self.norm2 = LayerNorm(dim) self.add2 = Add() self.ffn = nn.Sequential( nn.Conv1d(dim, ffn_dim, 1, 1, bias=True), nn.GELU(), nn.Dropout(p=drop), nn.Conv1d(ffn_dim, dim, 1, 1, bias=True), nn.Dropout(p=drop) ) # forward方法保持不变 ``` ### 关键改进说明 1. **窗口割机制**: - 使用 `unfold` 操作将序列划为固定大小的窗口 - 支持序列长度不能被窗口整除的情况(自动填充处理) - 窗口内独立计算注意力,显著降低计算复杂度 2. **多头注意力设计**: - 支持可配置的注意力头数(`num_heads`) - 每个头处理部特征维度(`head_dim = dim // num_heads`) - 并行计算不同表示子空间的注意力 3. **输入输出兼容性**: - 保持输入输出维度与原始BroadcastAttention一致 - 无缝集成到现有BA_FFN_Block结构中 - 新增关键参数(`num_heads`和`window_size`)便于实验调节 ### 对比实验建议 1. **窗口大小影响**: ```python # 实验不同窗口尺寸 window_sizes = [4, 8, 16, 32] for ws in window_sizes: model = Liconvformer(..., window_size=ws) ``` 2. **注意力头数对比**: ```python # 测试不同头数配置 head_configs = [2, 4, 8] for heads in head_configs: model = Liconvformer(..., num_heads=heads) ``` 3. **长序列效率测试**: ```python # 对比不同序列长度的计算效率 seq_lengths = [256, 512, 1024, 2048] for L in seq_lengths: input = torch.randn(16, in_channel, L) measure_runtime(model, input) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值