一、torch.mul
该乘法可简单理解为矩阵各位相乘,一个常见的例子为向量点乘,源码定义为torch.mul(input,other,out=None)。其中other可以为一个数也可以为一个张量,other为数即张量的数乘。
该函数可触发广播机制(broadcast)。
tensor1 = 2*torch.ones(1,4)
tensor2 = 3*torch.ones(4,1)
print(torch.mul(tensor1, tensor2))
#输出结果为:
tensor([[6., 6., 6., 6.],
[6., 6., 6., 6.],
[6., 6., 6., 6.],
[6., 6., 6., 6.]])
二、torch.mm
这是我们在线性代数课程中学习的矩阵乘法。该函数源码定义为torch.mm(input,mat2,out=None) ,参数与返回值均为tensor形式。
a=torch.ones(4,3)
b=2*torch.ones(3,2)
c=torch.empty(4,2)
torch.mm(a,b,out=c)
print(torch.mm(a,b))
print( c )
#输出结果为
tensor([[6., 6.],
[6., 6.],
[6., 6.],
[6., 6.]])
tensor([[6., 6.],
[6., 6.],
[6., 6.],
[6., 6.]])
三、torch.matmul
这个矩阵乘法是在torch.mm的基础上增加了广播机制,源码定义为torch.matmul(input,other,o

本文详细介绍了PyTorch中的三个矩阵乘法操作:torch.mul主要涉及元素级乘法,支持广播机制;torch.mm对应线性代数中的矩阵乘法;torch.matmul在torch.mm基础上增加广播机制,适用于不同形状的张量,包括一维和三维情况。
最低0.47元/天 解锁文章
1123

被折叠的 条评论
为什么被折叠?



