<think>
我们要求将BroadcastAttention替换为Windowing Multi-Head Attention。Windowing Multi-Head Attention通常是指将输入分割为多个窗口(或块),在每个窗口内进行多头注意力计算,以减少计算复杂度并捕捉局部依赖。
注意:原BroadcastAttention设计用于一维时序信号(Conv1d),因此Windowing Multi-Head Attention也需要适应一维输入。
步骤:
1. 我们将实现一个名为WindowMultiHeadAttention的模块,它应该能够处理形状为[B, C, L]的输入(其中L是序列长度)。
2. 由于是多头注意力,我们首先需要将输入通过线性变换拆分为多个头。
3. 然后,将序列划分为多个窗口(非重叠或重叠),在每个窗口内计算自注意力。
4. 最后,合并所有窗口的输出并经过另一个线性变换。
设计要点:
- 窗口大小(window_size)和步长(stride)可以设定,通常窗口大小固定,步长等于窗口大小(非重叠)或小于(重叠)。
- 为了简化,我们使用非重叠窗口。如果序列长度不能整除窗口大小,可以填充或截断,这里选择填充(padding)至整数倍。
由于原模型处理的是时序信号,且经过多次下采样,序列长度会逐渐减小,因此窗口大小可以设为固定值或根据序列长度自适应。这里我们设为可配置参数。
另外,多头注意力的头数(num_heads)也需要设定。
我们按照标准的Multi-Head Self-Attention(MHSA)实现,但限制在每个窗口内。
注意:原始BroadcastAttention输出维度与输入相同,我们的新模块也应保持相同。
实现方案:
1. 输入x: [B, C, L] -> 首先通过LayerNorm(因为在BA_FFN_Block中,输入注意力前有LayerNorm,所以这里可以不用,但为了与标准Transformer一致,我们可以在模块内部做或不做?根据ViT等模型,通常先做LayerNorm,但原BA_FFN_Block已经有一个LayerNorm。因此,我们在WindowMultiHeadAttention内部不包含LayerNorm,仅做线性变换和窗口划分)。
2. 通过三个线性层(或一个线性层再拆分)得到Q、K、V,然后分割成多头。
3. 将序列划分为窗口,在每个窗口内计算注意力。
4. 合并窗口,然后通过输出线性层。
由于输入是[B, C, L],而标准注意力在序列维度(L)上计算,因此我们需要将通道C作为特征维度。这里,我们设定每个头的维度为head_dim = C // num_heads(要求可整除)。
具体步骤:
a) 对输入x(经过LayerNorm后的输入,但LayerNorm在BA_FFN_Block中已经应用)进行线性变换得到Q、K、V。
b) 将Q、K、V按头拆分:reshape -> [B, num_heads, L, head_dim]
c) 将序列分割成窗口,假设窗口大小为w,那么序列会被分成n_w = ceil(L/w)个窗口(可能需要填充)。
- 将Q、K、V reshape为 [B, num_heads, n_w, w, head_dim]
d) 在每个窗口内计算注意力:attention = softmax(Q @ K.transpose(-1,-2) / sqrt(head_dim)) @ V
-> 输出 [B, num_heads, n_w, w, head_dim]
e) 合并窗口:reshape -> [B, num_heads, n_w*w, head_dim] -> 然后去掉填充部分(如果有)得到 [B, num_heads, L, head_dim]
f) 合并多头:reshape -> [B, L, C] (注意:这里我们暂时将通道维放在最后,因为线性层通常作用在最后一维)
g) 通过输出线性层,并转置回 [B, C, L](因为后续处理需要通道在中间维度)
注意:由于原模型使用Conv1d,特征维度是通道C,而序列长度是L。在注意力中,我们通常将序列长度视为token数,每个token的维度是C。
但这里有一个问题:原输入是[B, C, L],而标准注意力要求[B, L, C]。因此,在注意力计算前,我们需要将x转置为[B, L, C],计算完后再转置回来。
因此,我们在WindowMultiHeadAttention模块中:
- 首先将输入从[B, C, L]转置为[B, L, C]
- 然后进行多头注意力的计算
- 最后输出再转置回[B, C, L]
另外,为了与原始BroadcastAttention的输入输出一致,我们还需要考虑投影层和dropout等。
我们设计如下:
参数:
- dim: 输入通道数(即C)
- num_heads: 多头注意力的头数
- window_size: 窗口大小
- qkv_bias: 线性层是否使用偏置
- attn_drop: 注意力权重的dropout率
- proj_drop: 输出投影的dropout率
步骤:
1. 转置输入: [B, C, L] -> [B, L, C]
2. 通过线性层生成qkv: [B, L, C] -> [B, L, 3*C] (然后拆分为q, k, v)
3. reshape q, k, v 为 [B, L, num_heads, head_dim] -> 调整维度为 [B, num_heads, L, head_dim]
4. 将序列分窗口:将q, k, v reshape为 [B, num_heads, n_w, window_size, head_dim]
5. 计算窗口内注意力(注意:每个窗口独立计算)
6. 合并窗口:reshape回 [B, num_heads, L, head_dim]
7. 合并多头:reshape为 [B, L, C]
8. 通过输出线性层和dropout
9. 转置输出为 [B, C, L]
注意:如果序列长度不是窗口大小的整数倍,我们需要填充。这里我们在分窗口前进行填充,并在合并窗口后移除填充。
为了简化,我们使用torch.nn.functional.pad进行填充。
但注意:原模型输入长度可能变化(如下采样后长度会变短),所以窗口大小可能需要根据实际情况设定(比如设定为固定值,如16或32)。在模型中,我们可以在初始化时根据输入长度来调整?但通常固定窗口大小。
我们实现一个填充函数和一个反填充函数(在内部处理)。
代码实现:
由于代码较长,我们只替换BroadcastAttention为WindowMultiHeadAttention,并保持BA_FFN_Block的其他部分不变。
注意:在BA_FFN_Block中,我们使用了LayerNorm(转置操作),所以输入WindowMultiHeadAttention的是[B, C, L](但经过转置归一化后,又转置回来,所以实际上输入WindowMultiHeadAttention的x是[B, C, L])。因此,我们直接处理[B, C, L]并输出[B, C, L]。
下面实现WindowMultiHeadAttention模块:
</think>
### 修改后的 Windowing Multi-Head Attention 实现
以下是替换 `BroadcastAttention` 为 Windowing Multi-Head Attention 的完整代码实现:
```python
class WindowMultiHeadAttention(nn.Module):
def __init__(self,
dim,
num_heads=4,
window_size=8,
proj_drop=0.,
attn_drop=0.,
qkv_bias=True):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.head_dim = dim // num_heads
assert dim % num_heads == 0, "dim must be divisible by num_heads"
# QKV projection
self.qkv_proj = nn.Conv1d(dim, 3 * dim, kernel_size=1, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv1d(dim, dim, kernel_size=1, bias=qkv_bias)
self.proj_drop = nn.Dropout(proj_drop)
def window_partition(self, x):
"""
将输入序列划分为多个窗口
[B, C, L] -> [B * num_windows, num_heads, window_size, head_dim]
"""
B, C, L = x.shape
# 填充序列长度使其可被窗口整除
padding = (self.window_size - L % self.window_size) % self.window_size
x = F.pad(x, (0, padding))
L_padded = L + padding
# 划分窗口 [B, C, L] -> [B * num_windows, C, window_size]
x = x.unfold(2, self.window_size, self.window_size)
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(-1, C, self.window_size)
# 多头分割 [B * num_windows, C, window_size] -> [B * num_windows, num_heads, window_size, head_dim]
x = x.view(-1, self.num_heads, self.head_dim, self.window_size)
x = x.permute(0, 1, 3, 2).contiguous()
return x, padding
def window_reverse(self, x, L, padding):
"""
将窗口合并回完整序列
[B * num_windows, num_heads, window_size, head_dim] -> [B, C, L]
"""
B_numwindows, num_heads, W, head_dim = x.shape
B = B_numwindows // (L // self.window_size)
C = num_heads * head_dim
# 转换形状 [B * num_windows, num_heads, window_size, head_dim]
# -> [B, num_windows, num_heads, window_size, head_dim]
x = x.view(B, -1, num_heads, W, head_dim)
x = x.permute(0, 2, 4, 1, 3).contiguous()
x = x.view(B, C, -1)
# 去除填充
if padding > 0:
x = x[:, :, :-padding] if padding != self.window_size else x[:, :, :L]
return x
def forward(self, x):
B, C, L_orig = x.shape
# 生成QKV [B, C, L] -> [B, 3C, L]
qkv = self.qkv_proj(x)
q, k, v = torch.split(qkv, self.dim, dim=1)
# 窗口划分
q, padding = self.window_partition(q) # [B * nW, num_heads, W, head_dim]
k, _ = self.window_partition(k)
v, _ = self.window_partition(v)
# 计算注意力分数
attn_scores = torch.matmul(q, k.transpose(-2, -1)) # [B * nW, num_heads, W, W]
attn_scores = attn_scores / math.sqrt(self.head_dim)
attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = self.attn_drop(attn_probs)
# 注意力加权
context = torch.matmul(attn_probs, v) # [B * nW, num_heads, W, head_dim]
# 还原窗口
context = self.window_reverse(context, L_orig, padding) # [B, C, L]
# 输出投影
out = self.proj(context)
out = self.proj_drop(out)
return out
```
### 集成到 BA_FFN_Block 中
```python
class BA_FFN_Block(nn.Module):
def __init__(self,
dim,
ffn_dim,
drop=0.,
attn_drop=0.,
num_heads=4, # 新增参数
window_size=8 # 新增参数
):
super().__init__()
self.norm1 = LayerNorm(dim)
self.add1 = Add()
# 替换为WindowMultiHeadAttention
self.attn = WindowMultiHeadAttention(dim=dim,
num_heads=num_heads,
window_size=window_size,
attn_drop=attn_drop,
proj_drop=drop)
# 以下部分保持不变
self.norm2 = LayerNorm(dim)
self.add2 = Add()
self.ffn = nn.Sequential(
nn.Conv1d(dim, ffn_dim, 1, 1, bias=True),
nn.GELU(),
nn.Dropout(p=drop),
nn.Conv1d(ffn_dim, dim, 1, 1, bias=True),
nn.Dropout(p=drop)
)
# forward方法保持不变
```
### 关键改进说明
1. **窗口分割机制**:
- 使用 `unfold` 操作将序列划分为固定大小的窗口
- 支持序列长度不能被窗口整除的情况(自动填充处理)
- 窗口内独立计算注意力,显著降低计算复杂度
2. **多头注意力设计**:
- 支持可配置的注意力头数(`num_heads`)
- 每个头处理部分特征维度(`head_dim = dim // num_heads`)
- 并行计算不同表示子空间的注意力
3. **输入输出兼容性**:
- 保持输入输出维度与原始BroadcastAttention一致
- 无缝集成到现有BA_FFN_Block结构中
- 新增关键参数(`num_heads`和`window_size`)便于实验调节
### 对比实验建议
1. **窗口大小影响**:
```python
# 实验不同窗口尺寸
window_sizes = [4, 8, 16, 32]
for ws in window_sizes:
model = Liconvformer(..., window_size=ws)
```
2. **注意力头数对比**:
```python
# 测试不同头数配置
head_configs = [2, 4, 8]
for heads in head_configs:
model = Liconvformer(..., num_heads=heads)
```
3. **长序列效率测试**:
```python
# 对比不同序列长度的计算效率
seq_lengths = [256, 512, 1024, 2048]
for L in seq_lengths:
input = torch.randn(16, in_channel, L)
measure_runtime(model, input)
```