torch.mv 用法

在 PyTorch 中,torch.mv 是专门用于矩阵与向量相乘的高效函数。以下是详细使用指南:


一、核心功能

# 数学等价公式
输出向量 = 矩阵 [m×n] × 向量 [n]
# 结果维度:向量 [m]
result = torch.mv(matrix, vector)

二、参数规范

参数要求示例维度
matrix必须是 2D 张量 (矩阵)[5, 3], [2, 100]
vector必须是 1D 张量 (向量)[3], [100]
import torch

# 正确用法
matrix = torch.randn(3, 4)  # 3x4 矩阵
vector = torch.randn(4)     # 4 维向量
output = torch.mv(matrix, vector)  # 结果维度 [3]

# 错误用法示例
try:
    wrong_vec = torch.randn(4, 1)  # 2D 张量(列向量)
    torch.mv(matrix, wrong_vec)    # 报错:vector必须是1D
except RuntimeError as e:
    print(f"错误信息: {e}")

三、与其他乘法操作对比

函数用途输入维度要求
torch.mv()矩阵 × 向量matrix [m×n], vector [n]
torch.mm()矩阵 × 矩阵[m×n] × [n×p] → [m×p]
torch.matmul()广义矩阵乘法(支持广播)混合维度
# 用 mm 实现 mv 的效果(需要维度转换)
matrix = torch.randn(3, 4)
vector = torch.randn(4, 1)          # 列向量(2D)
result_mm = torch.mm(matrix, vector).squeeze()  # 结果维度 [3]

# 用 matmul 的等价写法
result_matmul = torch.matmul(matrix, vector)    # 自动处理1D/2D

四、性能优化技巧

  1. 数据类型对齐

    # 确保矩阵和向量数据类型一致
    matrix = matrix.to(torch.float32)
    vector = vector.to(torch.float32)
    
  2. 设备一致性

    # 同时位于 GPU 时加速计算
    if torch.cuda.is_available():
        matrix = matrix.cuda()
        vector = vector.cuda()
    
  3. 避免不必要拷贝

    # 对非连续内存数据进行优化
    if not matrix.is_contiguous():
        matrix = matrix.contiguous()
    

五、实际应用场景

1. 线性变换计算
# 模拟全连接层前向传播
weights = torch.randn(256, 784)  # 全连接层权重
input_vec = torch.randn(784)     # 输入向量
output = torch.mv(weights, input_vec)  # 输出 [256]
2. 注意力机制计算
# 简化版注意力得分计算
query = torch.randn(10, 64)      # 10个查询向量
key_vector = torch.randn(64)     # 单个键向量
attention_scores = torch.mv(query, key_vector)  # [10]

六、常见错误排查

  1. 维度不匹配

    # 矩阵列数必须等于向量长度
    matrix = torch.randn(3, 5)
    vector = torch.randn(4)       # 长度不等于5
    try:
        torch.mv(matrix, vector)
    except RuntimeError as e:
        print(f"维度错误: {e}")  # 显示尺寸不匹配信息
    
  2. 误用2D向量

    # 修正列向量为1D
    column_vector = torch.randn(5, 1)
    corrected_vector = column_vector.squeeze()  # 转为 [5]
    

七、进阶用法

# 批量矩阵向量乘法 (需使用 matmul)
batch_matrix = torch.randn(100, 3, 4)  # 100个3x4矩阵
batch_vectors = torch.randn(100, 4)   # 100个4维向量

# 使用 einsum 表示法
result = torch.einsum('bij,bj->bi', batch_matrix, batch_vectors)  # [100, 3]

总结建议

  • 优先使用场景:当明确第二个操作数是1D向量时,torch.mv 是最高效的选择
  • 灵活性需求:需要处理混合维度时改用 torch.matmul
  • 调试技巧:使用 print(tensor.shape) 确保维度匹配
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值