1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None),其中mat1(\(n\times m\)),mat2(\(m\times d\)),输出out的维度是(\(n\times d\))。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1(\(b\times n \times m\)),bmat2(\(b\times m \times d\)),输出out的维度是(\(b \times n \times d\))。
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
3. 多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)支持broadcast操作,使用起来比较复杂。
针对多维数据 matmul()乘法,我们可以认为该matmul()乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input(

本文详细介绍了PyTorch中四种常见的矩阵运算方法:二维矩阵乘法torch.mm(), 三维带batch的矩阵乘法torch.bmm(), 多维矩阵乘法torch.matmul(), 以及矩阵逐元素乘法torch.mul()。此外,还对比了运算符@和*在矩阵运算中的不同作用。
最低0.47元/天 解锁文章
2211

被折叠的 条评论
为什么被折叠?



