【PyTorch】张量(Tensor)核心操作解析

张量是深度学习框架(如 PyTorch、TensorFlow)的核心数据结构,其操作涵盖数学运算、形状调整、索引切片、广播机制等。以下从 ​操作类型​ 和 ​应用场景​ 两个维度解析核心操作:


一、数学运算

1. ​逐元素运算

对张量中的每个元素独立计算,​输入和输出形状相同

  • 示例​:
    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5, 6])
    
    # 加法
    c = a + b                   # tensor([5, 7, 9])
    
    # 乘法
    d = a * b                   # tensor([4, 10, 18])
    
    # 函数运算(如指数)
    e = torch.exp(a)            # tensor([2.7183, 7.3891, 20.0855])
2. ​矩阵运算

涉及线性代数运算,如矩阵乘法、转置、求逆等。

  • 矩阵乘法​:
    x = torch.tensor([[1, 2], [3, 4]])  # shape (2,2)
    y = torch.tensor([[5, 6], [7, 8]])  # shape (2,2)
    z = torch.matmul(x, y)              # [[19, 22], [43, 50]]
  • 转置​:
    x_t = x.T                    # [[1, 3], [2, 4]]

二、形状操作

1. ​维度重塑
  • view() 与 reshape()
    x = torch.arange(6)          # shape (6,)
    y = x.view(2, 3)            # [[0,1,2], [3,4,5]](共享内存)
    z = x.reshape(3, 2)         # 自动处理非连续张量
2. ​维度交换
  • permute():重新排列维度顺序。
    x = torch.rand(2, 3, 4)     # shape (2,3,4)
    y = x.permute(1, 0, 2)      # shape (3,2,4)
3. ​压缩/扩展维度
  • squeeze() 与 unsqueeze()
    x = torch.rand(1, 3, 1, 5)
    y = x.squeeze()             # 移除所有大小为1的维度 → (3,5)
    z = y.unsqueeze(0)          # 在0维添加新维度 → (1,3,5)

三、索引与切片

1. ​基础索引
  • 规则​:与 NumPy 类似,支持整数、切片、布尔掩码。
    x = torch.rand(4, 5)        # shape (4,5)
    
    # 取第1行
    row_1 = x[0, :]             # shape (5,)
    
    # 取前两行、第2列之后的数据
    sub_tensor = x[:2, 2:]      # shape (2,3)
    
    # 布尔索引
    mask = x > 0.5
    selected = x[mask]          # 一维张量
2. ​高级索引
  • gather()​:按索引从指定维度收集数据。
    indices = torch.tensor([[0, 1], [2, 0]])
    y = x.gather(1, indices)    # 从每行取对应列的值

四、广播(Broadcasting)​

1. ​规则
  • 从后向前逐维度比较,若维度大小 ​相等​ 或 ​其中一个为1,则允许广播。
  • 示例​:
    a = torch.tensor([[1], [2], [3]])  # shape (3,1)
    b = torch.tensor([4, 5, 6])        # shape (3)
    c = a + b                        # 广播后 a→(3,3), b→(3,3)
2. ​显式广播
  • expand()​:显式扩展维度。
    a = torch.tensor([1, 2, 3])        # shape (3,)
    a_expanded = a.expand(2, 3)         # shape (2,3) → [[1,2,3], [1,2,3]]

五、自动微分(Autograd)​

1. ​梯度跟踪
  • requires_grad=True 启用梯度跟踪。
    x = torch.tensor(2.0, requires_grad=True)
    y = x ​**​ 2 + 3 * x
    y.backward()              # 计算梯度
    print(x.grad)             # dy/dx = 2x +3 → 7.0
2. ​梯度禁用
  • with torch.no_grad():临时关闭梯度计算以优化性能。
    with torch.no_grad():
        y = x * 2            # 不记录计算图

六、拼接与分割

1. ​拼接(Concatenation)​
  • torch.cat():沿指定维度拼接张量。
    a = torch.tensor([[1, 2], [3, 4]])  # shape (2,2)
    b = torch.tensor([[5, 6]])          # shape (1,2)
    c = torch.cat([a, b], dim=0)        # shape (3,2)
2. ​分割(Splitting)​
  • torch.split():按块大小或数量分割。
    x = torch.arange(10)                # shape (10,)
    chunks = torch.split(x, 3)           # [tensor([0,1,2]), tensor([3,4,5]), ...]

七、总结与场景指南

操作类型典型场景注意事项
逐元素运算激活函数(如 torch.relu()确保输入形状一致
矩阵运算全连接层、注意力机制检查维度匹配(如 matmul 的维度对齐)
形状操作输入预处理(如展平图像)注意内存连续性和元素总数一致性
广播不同形状张量的算术运算(如添加偏置)显式广播避免意外行为
自动微分自定义损失函数或网络层及时清零梯度(x.grad.zero_()

示例代码整合​:

# 创建张量并启用梯度
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)

# 矩阵运算与形状调整
y = torch.matmul(x, x.T)                # shape (2,2)
z = y.view(4)                           # shape (4,)

# 自动微分
z.sum().backward()                      
print(x.grad)                           # 梯度计算

掌握这些核心操作,能够高效处理数据流、构建复杂模型并优化计算性能。建议结合实际项目(如图像分类、序列建模)深化理解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浩瀚之水_csdn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值