unsqueeze,squeeze,repeat,expand,expand_as应用示例及其详细说明

本文详细介绍了PyTorch中张量的squeeze、unsqueeze、expand和repeat操作,包括它们的功能、用法以及注意事项。通过示例代码解释了这些操作如何改变张量的形状,帮助理解在深度学习模型中如何处理数据维度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
box_a=torch.tensor([[1,2,3],[1,2,3]])
print(box_a[:, 2:].squeeze(0).shape)
print(box_a[:, 2:].shape)
# torch.Size([2, 1])
# torch.Size([2, 1])
print("#################################")
print(box_a[:, 2:].squeeze(1).shape)
print(box_a[:, 2:].shape)
'''squeeze 只能去掉维度为1的维度,如果目标维度不是1
比如是2,就不可以去掉'''
# torch.Size([2])
# torch.Size([2, 1])
print("#################################")
print(box_a[:, 2:].unsqueeze(0).shape)
print(box_a[:, 2:].shape)
'''unsqueeze 是在目标维度上增加一个维度,示例为在
第一个维度dim=0的位置添加一个维度'''
# torch.Size([1, 2, 1])
# torch.Size([2, 1])
print("#################################")
print(box_a[:, 2:].shape)
print(box_a[:, 2:].expand(2,1000).shape)
'''expand 只能扩展size为1的维度
参考 :https://blog.youkuaiyun.com/weixin_39504171/article/details/106090626
'''
# torch.Size([2, 1])
# torch.Size([2, 1000])
print("#################################")
print(box_a[:, 2:].shape)
print(box_a.shape)
print(box_a[:, 2:].expand_as(box_a).shape)
print(box_a[:, 2:])
print(box_a)
print(box_a[:, 2:].expand_as(box_a))
# torch.Size([2, 1])
# torch.Size([2, 3])
# torch.Size([2, 3])
# tensor([[3],
#         [3]])
# tensor([[1, 2, 3],
#         [1, 2, 3]])
# tensor([[3, 3, 3],
#         [3, 3, 3]])
'''expand_as 可以扩展到目标维度,只不过expand_as '''
# print(box_a[:, 2:].expand_as(box_a.unsqueeze(0)))
'''expand_as 比较复杂的示例 torch.Size([2, 1, 1]) 扩展到 torch.Size([4, 7, 1]) '''
# print(box_a[:, 2:].unsqueeze(1).expand_as(torch.rand(2,7,1)))
print(box_a[:, 2:].unsqueeze(1).shape)
# torch.Size([2, 1, 1])
print("#################################")
'''repeat 操作 '''
print(box_a.shape)
'''repeat (1)操作用来阔展维度 相当于 unsqueeze(0) '''
print(box_a.repeat(1,1,1,1).shape)
# RuntimeError: Number of dimensions of repeat dims can not be
# smaller than number of dimensions of tensor
'''复制后的最小维度数目不能小于原先的维度数目,比如原来是两个维度,
复制后的目标只有一个维度就不行了'''
# print(box_a.repeat(1).shape)

 

def rknn_backward(cache_list: List[Dict], window_size: int, param: Dict) -> Tuple[torch.Tensor, ...]: # 初始化梯度 (保持不变) grad_L1_w = torch.zeros_like(cache_list[0]['Linear1_weight']) grad_L1_b = torch.zeros_like(cache_list[0]['Linear1_bias']) grad_L2_w = torch.zeros_like(cache_list[0]['Linear2_weight']) grad_L2_b = torch.zeros_like(cache_list[0]['Linear2_bias']) grad_L3_w = torch.zeros_like(cache_list[0]['Linear3_weight']) grad_L3_b = torch.zeros_like(cache_list[0]['Linear3_bias']) # 反向传播通过时间窗口 for cache in cache_list: # 确保所有输入是 PyTorch 张量 pos_error = cache['pos_error'] weight3 = cache['Linear3_weight'] # 计算输出梯度并确保维度为1×2 d_output = (-pos_error * 0.1) # 基础计算 # 转换为1×2矩阵 if d_output.dim() == 0: # 标量 d_output = d_output.reshape(1, 1) # 转为1×1 d_output = d_output.expand(1, 2) # 复制为1×2 else: d_output = d_output.view(1, -1) # 转为行向量 if d_output.size(1) < 2: # 如果列数不足2 d_output = d_output.repeat(1, 2) # 重复列 # 确保权重是二维 if weight3.dim() == 1: weight3 = weight3.unsqueeze(1) # [n] -> [n, 1] # ===== 关键修正部分 ===== # 1. 计算隐藏层梯度 (反向传播到第三层输入) # 正确公式: d_hidden2 = d_output @ weight3.t() d_hidden2 = d_output @ weight3 # (1×2) @ (2×hidden_dim) -> (1×hidden_dim) # 2. 计算第三层参数梯度 # 权重梯度: ∂L/∂W3 = (d_output)^T @ h2 grad_L3_w += d_output.t() @ cache['hidden'] # (1) @ (1×hidden_dim) -> (2×hidden_dim) # 偏置梯度: ∂L/∂b3 = sum(d_output, dim=0) grad_L3_b += d_output.squeeze(0) # 从(1×2)转为(2) # 3. 应用ReLU激活函数的导数 # 计算第二层输出的梯度: d_linear2 = d_hidden2 * relu'(z2) # relu'(z) = 1 if z > 0 else 0 relu_deriv = (cache['hidden_prev'] > 0).float() # ReLU导数 d_linear2 = d_hidden2 * relu_deriv # 4. 计算第二层参数梯度 # 权重梯度: ∂L/∂W2 = (d_linear2)^T @ h1 grad_L2_w += d_linear2.t() @ cache['input'] # (hidden_dim×1) @ (1×input_dim) -> (hidden_dim×input_dim) # 偏置梯度: ∂L/∂b2 = sum(d_linear2, dim=0) grad_L2_b += d_linear2.squeeze(0) # 从(1×hidden_dim)转为(hidden_dim) # 5. 继续反向传播到第一层 # 计算第一层输入的梯度: d_hidden1 = d_linear2 @ weight2 weight2 = cache['Linear2_weight'] if weight2.dim() == 1: weight2 = weight2.unsqueeze(1) d_hidden1 = d_linear2 @ weight2 # (1×hidden_dim) @ (hidden_dim×input_dim) -> (1×input_dim) # 应用ReLU激活函数的导数到第一层 relu_deriv1 = (cache['input'] > 0).float() # 第一层ReLU导数 d_linear1 = d_hidden1 * relu_deriv1 # 6. 计算第一层参数梯度 # 权重梯度: ∂L/∂W1 = (d_linear1)^T @ input # 注意: 这里cache['input']是原始输入 grad_L1_w += d_linear1.t() @ cache['input'] # (input_dim×1) @ (1×input_dim) -> (input_dim×input_dim) # 偏置梯度: ∂L/∂b1 = sum(d_linear1, dim=0) grad_L1_b += d_linear1.squeeze(0) # 从(1×input_dim)转为(input_dim) return grad_L1_w, grad_L1_b, grad_L2_w, grad_L2_b, grad_L3_w, grad_L3_b这一部分代码报错,帮我修改
最新发布
07-17
### PyTorch 中 `torch.gather` 和 `repeat` 的搭配用法 #### 什么是 `torch.gather`? `torch.gather` 是 PyTorch 中的一个操作,用于从输入张量的不同维度收集指定索引位置上的元素。它允许按照给定的索引来提取特定的数据片段。 其基本语法为: ```python torch.gather(input, dim, index, out=None, sparse_grad=False) ``` 其中: - `input`: 输入张量。 - `dim`: 要沿着哪个维度进行聚集操作。 - `index`: 指定要采集的索引值,形状需与目标一致[^1]。 #### 什么是 `repeat` 或 `.expand()` 方法? `.repeat()` 是一种扩展张量的方法,可以重复某个张量的内容以匹配新的形状。它的作用类似于 NumPy 的 `tile` 函数。对于简单的广播机制不足的情况,可以通过 `.repeat()` 来实现更复杂的形状调整。 其基本语法为: ```python tensor.repeat(*sizes) ``` 参数 `*sizes` 表示各个维度上需要重复的次数[^2]。 --- #### 组合使用场景分析 当我们将 `torch.gather` 和 `repeat` 结合起来时,通常是为了处理以下情况之一: 1. **动态索引选择并复制**:先通过 `gather` 收集某些特定数据,再利用 `repeat` 将这些数据沿某一维度扩展到更大的规模。 2. **批量操作中的灵活映射**:在批处理模式下,针对不同样本执行个性化的特征抽取或变换。 下面是一个具体的例子展示如何组合这两个功能。 --- #### 使用案例 假设有一个二维矩阵 A,以及一组对应每一行的最大值索引 B。我们需要根据索引 B 提取每行最大值,并将其扩展成一个新的三维张量 C。 ##### 示例代码 ```python import torch # 创建一个随机二维张量 A (batch_size=3, feature_dim=4) A = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float) # 假设我们知道每行的最大值对应的列索引 B = torch.tensor([3, 2, 1]) # shape: (batch_size,) -> [3, 2, 1] # Step 1: 使用 gather 获取每行的最大值 max_values_per_row = torch.gather(A, 1, B.unsqueeze(1)) # unsqueeze 使 B 变为 (batch_size, 1) print("Max values per row:", max_values_per_row.squeeze()) # Output: tensor([4., 7., 10.]) # Step 2: 扩展结果至更高维空间 C = max_values_per_row.expand(-1, 4).unsqueeze(2) # expand(batch_size, 4), 并增加一维变为 (batch_size, 4, 1) D = C.repeat(1, 1, 3) # 进一步 repeat(batch_size, 4, 3) print(D.shape) # 输出应为 torch.Size([3, 4, 3]) ``` 上述代码展示了如何结合 `torch.gather` 和 `repeat` 完成复杂的数据转换任务。这里的关键在于理解 `gather` 如何定位所需数据,以及 `repeat` 怎样改变张量结构[^3]。 --- #### 注意事项 1. 当调用 `gather` 时,确保索引张量 (`index`) 的尺寸与目标维度相兼容。 2. 对于高阶张量的操作,务必注意各维度的意义及其顺序关系。 3. 如果涉及 GPU 计算,请确认所有参与运算的对象均位于同一设备之上[^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值