PyTorch矩阵乘法函数区别解析与矩阵高级索引说明——《动手学深度学习》3.6.3、3.6.4和3.6.5 (P79)

部署运行你感兴趣的模型镜像

主要区别总结

函数输入要求输出维度支持广播使用场景
torch.matmul灵活灵活通用矩阵乘法
torch.mm两个2D张量2D严格矩阵乘法
torch.mv2D矩阵 × 1D向量1D矩阵向量乘法

推荐用法

  • 大多数情况:使用 torch.matmul@ 运算符(A @ B

  • 性能关键且确定维度:使用特定函数(mm/mv

  • 批量运算:必须使用 torch.matmul

# 现代PyTorch中推荐使用 @ 运算符
result = A @ B  # 等同于 torch.matmul(A, B)

torch.matmul 是最通用的选择,而 mmmv 或dot在你知道确切维度时可以提供更清晰的代码意图。

矩阵高级索引

 y_hat[[0, 1], y] 是 PyTorch 中的高级索引(advanced indexing)操作

代码解析:

y = torch.tensor([0, 2])           # 形状: (2,)
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])  # 形状: (2, 3)

result = y_hat[[0, 1], y]          # 结果: tensor([0.1000, 0.5000])

索引操作详解:

这个操作相当于:

  • y_hat[0, y[0]]y_hat[0, 0]0.1

  • y_hat[1, y[1]]y_hat[1, 2]0.5

逐步分解:

# 行索引: [0, 1]   列索引: [0, 2]
# 对应位置配对:
#   第1对: 行0, 列0 → 值0.1
#   第2对: 行1, 列2 → 值0.5

关于向量类型:

y_hat 是矩阵(2D张量):

print(y_hat.shape)  # torch.Size([2, 3])
# 这是一个 2×3 的矩阵,不是行向量也不是列向量

y 是1D张量:

print(y.shape)      # torch.Size([2])
# 这是一个包含2个元素的一维张量

结果是1D张量:

print(result.shape) # torch.Size([2])
# 结果是一个包含2个元素的一维张量

这种索引的实用场景:

这在机器学习中很常见,特别是在计算交叉熵损失时:

# 假设:
# y_hat 是预测的概率分布 (batch_size, num_classes)
# y 是真实标签 (batch_size,)

# 这种索引用于获取每个样本对应真实标签的预测概率
predicted_probs = y_hat[range(len(y)), y]
# 这在交叉熵损失计算中很有用

总结:

  • y_hat[[0, 1], y]配对索引操作

  • 返回的是每个 (行, 列) 对对应的元素

  • y_hat2D矩阵y1D向量,结果是1D向量

  • 这种操作在机器学习中常用于根据真实标签索引预测概率

最大值的索引位置

这是一个在机器学习和深度学习中使用非常频繁的代码片段。

基本含义

y_hat.argmax(axis=1) 表示:在第二个维度(axis=1)上找出最大值的索引位置

具体解释

在分类问题中的典型用法

假设 y_hat 是一个预测概率矩阵:

  • 每一行代表一个样本

  • 每一列代表一个类别的预测概率

import numpy as np

# 示例:3个样本,4个类别的预测概率
y_hat = np.array([
    [0.1, 0.2, 0.6, 0.1],  # 样本1:第3个类别概率最高(0.6)
    [0.7, 0.1, 0.1, 0.1],  # 样本2:第1个类别概率最高(0.7)
    [0.05, 0.05, 0.1, 0.8] # 样本3:第4个类别概率最高(0.8)
])

predictions = y_hat.argmax(axis=1)
print(predictions)  # 输出:[2, 0, 3]

维度说明

  • axis=0:沿着行方向(垂直)

  • axis=1:沿着列方向(水平)

对于二维数组:

[[x00, x01, x02],  ← axis=1(列方向)
 [x10, x11, x12],
 [x20, x21, x22]]
 ↑
axis=0(行方向)

实际应用场景

# 在神经网络分类中
import torch

# 模型输出(批量大小=4,类别数=3)
outputs = torch.tensor([
    [1.2, 0.5, -0.3],
    [0.1, 2.1, 0.8],
    [-0.5, 0.3, 1.7],
    [0.9, 0.6, 0.4]
])

# 获取预测类别
predicted_classes = outputs.argmax(dim=1)  # PyTorch中用dim
print(predicted_classes)  # 输出:tensor([0, 1, 2, 0])

总结

y_hat.argmax(axis=1) 的主要作用是:

  • 将概率分布转换为具体的类别预测

  • 找出每个样本最可能的类别

  • 常用于计算准确率和模型评估

这是在分类任务中从模型输出获取最终预测结果的常用方法。

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值