Understanding torch.argmax with dim=None on a High-Dimensional Tensor

This post looks at what PyTorch's argmax function returns when dim=None. Using a small example, it shows how to interpret the single flat index that is returned and convert it into concrete coordinates for the maximum value of a four-dimensional tensor. It explains the relationship between the flat index and the individual dimensions, together with my own intuition for it, and should be useful for readers interested in PyTorch and high-dimensional data manipulation.

torch.argmax comes up frequently in deep learning models, and most explanations online cover the case where the dim argument is specified. I recently ran into a call with dim=None, didn't quite understand it, and couldn't find a good explanation despite a lot of searching. So I wrote a small example to figure it out, and I'm recording it here as a reference.

import torch

# Randomly generate a 4-D tensor
a = torch.rand((2, 2, 3, 4))
# Get the shape of a
b, d, w, h = a.shape
print(a)
# Index of the maximum value of a (dim=None, so the tensor is flattened first)
index = torch.argmax(a)
print(index)
# Recover the z index; z runs over the first two dimensions (b and d) treated as one
z = int(index // w // h)
index -= z * w * h
# Recover the x index; x corresponds to w in the shape
x = int(index // h)
index -= x * h
# Recover the y index; y corresponds to h in the shape
y = int(index)
# Print z, x, y
print(z, x, y)

Output:

tensor([[[[0.2936, 0.2863, 0.5059, 0.4042],
          [0.9422, 0.4629, 0.8336, 0.0168],
          [0.9002, 0.8628, 0.3787, 0.9284]],

         [[0.5688, 0.2993, 0.3334, 0.9471],
          [0.4500, 0.1274, 0.1956, 0.6806],
          [0.8735, 0.5767, 0.8293, 0.3108]]],


        [[[0.3033, 0.8770, 0.2276, 0.4150],
          [0.2653, 0.9783, 0.2614, 0.9467],
          [0.4042, 0.8505, 0.0225, 0.4542]],

         [[0.8606, 0.3494, 0.1172, 0.4817],
          [0.1268, 0.2600, 0.1153, 0.6345],
          [0.7228, 0.9589, 0.2653, 0.5185]]]])
tensor(29)
2 1 1

The code first generates a random tensor a. Looking at the printed output, the maximum value is 0.9783; if the first two dimensions (b and d) are viewed as a single stack of 3×4 planes, it sits in the third plane at row 1, column 1, i.e. at (z, x, y) = (2, 1, 1). torch.argmax(a) returns just one value, 29. Where does 29 come from? The shape of a is (2, 2, 3, 4), and 29 = 2×(3×4) + 1×4 + 1×1 = 24 + 4 + 1. The way to read this is: moving one step along the z direction requires filling a whole w×h block first; moving one step along w requires filling a whole row of h elements first; and along h, each step is a single element.

If that is still not clear, compare it with 256 = 2×100 + 5×10 + 6×1, where 100 = 10×10, so 256 = 2×(10×10) + 5×10 + 6×1. The flat index works the same way, except that each "base" is a dimension size instead of 10. With the code above, the exact position of the maximum can therefore be recovered even when dim is not specified.
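The same divmod idea generalizes to any number of dimensions. Below is a minimal sketch; the helper name unravel_flat_index is my own, not part of the original code, and the torch.unravel_index mention applies to newer PyTorch releases (2.2 and later), if available in your version.

import torch

def unravel_flat_index(flat_index, shape):
    # Convert a flat index (as returned by torch.argmax with dim=None)
    # into per-dimension indices via repeated divmod, just like base conversion.
    coords = []
    for size in reversed(shape):
        flat_index, remainder = divmod(int(flat_index), size)
        coords.append(remainder)
    return tuple(reversed(coords))

a = torch.rand((2, 2, 3, 4))
flat = torch.argmax(a)                     # single flat index
print(unravel_flat_index(flat, a.shape))   # full 4-D coordinate of the maximum
# On PyTorch 2.2+, torch.unravel_index(flat, a.shape) gives the same result.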

That covers the case where dim is not specified for a high-dimensional tensor. When dim is specified, there are already plenty of blog posts explaining the behaviour, so I won't go over it in detail here; a quick sketch is included below for reference.
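As a brief reminder of the dim-specified behaviour (standard torch.argmax semantics, shown only as a sketch): the result has the input shape with the reduced dimension removed, and each entry holds the index of the maximum along that dimension.

import torch

a = torch.rand((2, 2, 3, 4))

# argmax along the last dimension: result has shape (2, 2, 3),
# each entry is the position (0..3) of the maximum within its length-4 row
idx_last = torch.argmax(a, dim=-1)
print(idx_last.shape)    # torch.Size([2, 2, 3])

# keepdim=True keeps the reduced dimension with size 1
idx_keep = torch.argmax(a, dim=-1, keepdim=True)
print(idx_keep.shape)    # torch.Size([2, 2, 3, 1])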

The above is my personal understanding; if anything is wrong, corrections are welcome!
