class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
        # relative position bias table: (2*Wh-1) * (2*Ww-1) entries per head
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        # pair-wise relative position index for tokens inside one window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None, pfa_values=None, pfa_indices=None):
b_, n, c = x.shape
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
        # add relative position bias
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
        # Align the cached PFA (progressive focused attention) values with the current attention map
        if pfa_values is not None and len(pfa_values) > 0:
            # use the first cached entry (per-shift handling simplified)
            shift = 0
            if shift < len(pfa_values):
                try:
                    pfa_val = pfa_values[shift]
                    # match the number of dimensions of attn
                    while pfa_val.dim() < attn.dim():
                        pfa_val = pfa_val.unsqueeze(0)
                    while pfa_val.dim() > attn.dim():
                        pfa_val = pfa_val.squeeze(0)
                    # crop the last two (token) dimensions if they exceed attn's
                    pfa_val = pfa_val[..., :attn.shape[-2], :attn.shape[-1]]
                    # broadcast the remaining dimensions; fall back to a neutral weight on failure
                    if pfa_val.shape != attn.shape:
                        try:
                            pfa_val = pfa_val.expand_as(attn)
                        except RuntimeError:
                            pfa_val = torch.ones_like(attn)
                    # apply progressive focusing by re-weighting the attention logits
                    attn = attn * pfa_val
                except Exception as e:
                    print(f"PFA application warning: {e}, attn shape: {attn.shape}")
                    # continue without applying PFA
if mask is not None:
nw = mask.shape[0]
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, n, n)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
        # sparse matrix multiplication (SMM) over the cached top-k indices, with a dense fallback
        if pfa_indices is not None and len(pfa_indices) > 0:
            shift = 0
            if shift < len(pfa_indices):
                try:
                    smm_index = pfa_indices[shift]
                    # sparse attention @ value product via the SMM_AmV autograd function
                    x = SMM_AmV.apply(attn, v, smm_index)
                except Exception as e:
                    print(f"SMM operation warning: {e}")
                    # fall back to the dense computation
                    x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
            else:
                x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
        else:
            x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, n):
# calculate flops for 1 window with token length of n
flops = 0
# qkv = self.qkv(x)
flops += n * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * n * (self.dim // self.num_heads) * n
# x = (attn @ v)
flops += self.num_heads * n * n * (self.dim // self.num_heads)
# x = self.proj(x)
flops += n * self.dim * self.dim
return flops
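# Minimal usage sketch (not from the original file): it shows the expected input layout for
# WindowAttention, assuming the usual Swin-style window partition where x has shape
# (num_windows * B, window_size[0] * window_size[1], dim). Values below are illustrative.
#
#     attn_layer = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#     x = torch.randn(8, 7 * 7, 96)   # 8 windows of 49 tokens, 96 channels each
#     out = attn_layer(x)             # -> (8, 49, 96)
#     print(attn_layer.flops(7 * 7))  # per-window FLOPs for 49 tokens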
class PFWindowAttention(nn.Module):
r""" Progressive Focused Window based multi-head self attention (PF-W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        num_topk (int): Number of attention entries kept per query (top-k) for progressive focusing.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        layer_id (int, optional): Index of the layer this attention module belongs to. Default: 0
"""
def __init__(self, dim, window_size, num_heads, num_topk, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
layer_id=0):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.layer_id = layer_id
self.topk = num_topk
self.eps = 1e-8
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
        self.memory_efficient = True  # enable memory-efficient mode
        # self.max_topk = 256  # cap top-k to avoid memory blow-up
def forward(self, x, mask=None, pfa_values=None, pfa_indices=None):
b_, n, c = x.shape
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
        # align dimensions of k with q (key to fixing the CUDA error)
k = k[..., :q.shape[-1]]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
        # numerical stability: subtract the row-wise max before softmax
attn_max = attn.detach().max(dim=-1, keepdim=True).values
attn = attn - attn_max
        # relative position bias (index lookup omitted in this snippet; see WindowAttention above)
        relative_position_bias = self.relative_position_bias_table[...]
attn = attn + relative_position_bias.unsqueeze(0)
        # PFA: re-weight the attention map with cached progressive-focusing values
        shift = 0
        if pfa_values is not None and shift < len(pfa_values):
            pfa_val = pfa_values[shift]
            # align spatial size with the current attention map
            if pfa_val.shape[-2:] != attn.shape[-2:]:
                pfa_val = F.interpolate(pfa_val, size=attn.shape[-2:], mode='nearest')
            attn = attn * pfa_val
        # softmax over the key dimension
attn = self.softmax(attn)
attn = self.attn_drop(attn)
        # sparse matrix multiplication over the cached top-k indices (core fix)
        shift = 0
        try:
            if pfa_indices is not None and shift < len(pfa_indices):
                smm_index = pfa_indices[shift]
                # clamp indices into the valid token range
                smm_index = torch.clamp(smm_index, 0, v.size(2) - 1)
                x = SMM_AmV_Safe.apply(attn, v, smm_index)
            else:
                # dense matrix multiplication
                x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
        except RuntimeError as e:
            # fall back to the dense computation on CUDA errors
            print(f"CUDA error fallback: {e}")
            x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, n):
# calculate flops for 1 window with token length of n
flops = 0
# qkv = self.qkv(x)
flops += n * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * n * (self.dim // self.num_heads) * n
# x = (attn @ v)
flops += self.num_heads * n * n * (self.dim // self.num_heads)
# x = self.proj(x)
flops += n * self.dim * self.dim
return flops
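# Usage sketch for PFWindowAttention (not from the original file, and assuming the elided
# relative-position-bias lookup above is restored as in WindowAttention). With
# pfa_values/pfa_indices left as None, only the dense attention path is exercised;
# num_topk is assumed to be the per-query top-k budget used by progressive focusing.
#
#     pf_attn = PFWindowAttention(dim=96, window_size=(7, 7), num_heads=3, num_topk=64)
#     x = torch.randn(8, 7 * 7, 96)
#     out = pf_attn(x, mask=None, pfa_values=None, pfa_indices=None)  # -> (8, 49, 96)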