张量是深度学习框架(如 PyTorch、TensorFlow)的核心数据结构,其操作涵盖数学运算、形状调整、索引切片、广播机制等。以下从 操作类型 和 应用场景 两个维度解析核心操作:
一、数学运算
1. 逐元素运算
对张量中的每个元素独立计算,输入和输出形状相同。
- 示例:
a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # 加法 c = a + b # tensor([5, 7, 9]) # 乘法 d = a * b # tensor([4, 10, 18]) # 函数运算(如指数) e = torch.exp(a) # tensor([2.7183, 7.3891, 20.0855])
2. 矩阵运算
涉及线性代数运算,如矩阵乘法、转置、求逆等。
- 矩阵乘法:
x = torch.tensor([[1, 2], [3, 4]]) # shape (2,2) y = torch.tensor([[5, 6], [7, 8]]) # shape (2,2) z = torch.matmul(x, y) # [[19, 22], [43, 50]]
- 转置:
x_t = x.T # [[1, 3], [2, 4]]
二、形状操作
1. 维度重塑
-
view()
与reshape()
:x = torch.arange(6) # shape (6,) y = x.view(2, 3) # [[0,1,2], [3,4,5]](共享内存) z = x.reshape(3, 2) # 自动处理非连续张量
2. 维度交换
-
permute()
:重新排列维度顺序。x = torch.rand(2, 3, 4) # shape (2,3,4) y = x.permute(1, 0, 2) # shape (3,2,4)
3. 压缩/扩展维度
-
squeeze()
与unsqueeze()
:x = torch.rand(1, 3, 1, 5) y = x.squeeze() # 移除所有大小为1的维度 → (3,5) z = y.unsqueeze(0) # 在0维添加新维度 → (1,3,5)
三、索引与切片
1. 基础索引
- 规则:与 NumPy 类似,支持整数、切片、布尔掩码。
x = torch.rand(4, 5) # shape (4,5) # 取第1行 row_1 = x[0, :] # shape (5,) # 取前两行、第2列之后的数据 sub_tensor = x[:2, 2:] # shape (2,3) # 布尔索引 mask = x > 0.5 selected = x[mask] # 一维张量
2. 高级索引
-
gather()
:按索引从指定维度收集数据。indices = torch.tensor([[0, 1], [2, 0]]) y = x.gather(1, indices) # 从每行取对应列的值
四、广播(Broadcasting)
1. 规则
- 从后向前逐维度比较,若维度大小 相等 或 其中一个为1,则允许广播。
- 示例:
a = torch.tensor([[1], [2], [3]]) # shape (3,1) b = torch.tensor([4, 5, 6]) # shape (3) c = a + b # 广播后 a→(3,3), b→(3,3)
2. 显式广播
-
expand()
:显式扩展维度。a = torch.tensor([1, 2, 3]) # shape (3,) a_expanded = a.expand(2, 3) # shape (2,3) → [[1,2,3], [1,2,3]]
五、自动微分(Autograd)
1. 梯度跟踪
-
requires_grad=True
启用梯度跟踪。x = torch.tensor(2.0, requires_grad=True) y = x ** 2 + 3 * x y.backward() # 计算梯度 print(x.grad) # dy/dx = 2x +3 → 7.0
2. 梯度禁用
-
with torch.no_grad()
:临时关闭梯度计算以优化性能。with torch.no_grad(): y = x * 2 # 不记录计算图
六、拼接与分割
1. 拼接(Concatenation)
-
torch.cat()
:沿指定维度拼接张量。a = torch.tensor([[1, 2], [3, 4]]) # shape (2,2) b = torch.tensor([[5, 6]]) # shape (1,2) c = torch.cat([a, b], dim=0) # shape (3,2)
2. 分割(Splitting)
-
torch.split()
:按块大小或数量分割。x = torch.arange(10) # shape (10,) chunks = torch.split(x, 3) # [tensor([0,1,2]), tensor([3,4,5]), ...]
七、总结与场景指南
操作类型 | 典型场景 | 注意事项 |
---|---|---|
逐元素运算 | 激活函数(如 torch.relu() ) | 确保输入形状一致 |
矩阵运算 | 全连接层、注意力机制 | 检查维度匹配(如 matmul 的维度对齐) |
形状操作 | 输入预处理(如展平图像) | 注意内存连续性和元素总数一致性 |
广播 | 不同形状张量的算术运算(如添加偏置) | 显式广播避免意外行为 |
自动微分 | 自定义损失函数或网络层 | 及时清零梯度(x.grad.zero_() ) |
示例代码整合:
# 创建张量并启用梯度
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
# 矩阵运算与形状调整
y = torch.matmul(x, x.T) # shape (2,2)
z = y.view(4) # shape (4,)
# 自动微分
z.sum().backward()
print(x.grad) # 梯度计算
掌握这些核心操作,能够高效处理数据流、构建复杂模型并优化计算性能。建议结合实际项目(如图像分类、序列建模)深化理解。