今天看到一行pytorch的代码
import torch
from torch.autograd import Variable
tensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor, requires_grad=True)
v_out = torch.mean(variable*variable)
很理所当然的理解为两个矩阵相乘,但是打印输出看的时候觉得不对
tensor([[ 1., 4.],
[ 9., 16.]], grad_fn=<MulBackward0>)
这里明显做了一个点乘
那如何才能让这两个variable变量做矩阵的乘法呢
print(torch.mm(variable,variable))
tensor([[ 7., 10.],
[15., 22.]], grad_fn=<MmBackward>)
嗯,这样就可以了
本文通过实例解析了PyTorch中矩阵乘法的正确实现方式,区分了点乘与矩阵乘法的不同,并展示了如何使用torch.mm()函数进行矩阵乘法运算。

被折叠的 条评论
为什么被折叠?



