pytorch中的@
表示的是数学中的矩阵乘法,*
是数学中的Hadamard积(哈达玛积)
import torch
a = torch.tensor([[1,2],
[2,3],
[5,6]])
b = torch.tensor([[2,1],
[8,5],
[3,2]])
c = a*b
d = a @ b.t() ## [3,2] @ [2,3]
print('*',c)
print('@',d)
输出
-
做矩阵乘法时,要求两个矩阵的形状为[m,n]和[n,k],得到的是一个[m,k]的矩阵。
-
做Hadamarda积是,要求两个矩阵是同型矩阵,a和b都是2*3的。相同位置元素进行相乘
在pytorch中.t()
表示二维矩阵的转置
参考