torch.bmm() 与 torch.matmul()

本文详细解析了PyTorch中的两种矩阵运算方法:torch.bmm()和torch.matmul()。对比了它们在处理不同维度张量时的区别,包括强制维度大小相同的要求、广播机制的应用,以及在3D张量上的等效性。通过具体示例,展示了如何利用这些函数进行高效的张量运算。
该文章已生成可运行项目,

torch.bmm()

torch.matmul()


torch.bmm()强制规定维度和大小相同

torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作

当进行操作的两个tensor都是3D时,两者等同。

torch.bmm()

官网:https://pytorch.org/docs/stable/torch.html#torch.bmm

torch.bmm(inputmat2out=None) → Tensor

 torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。

参数:

input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。

output:输出结果

并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。

例子:

import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())

output:
torch.Size([2, 3, 7])

###############################
y = torch.rand(2,5,7) ##维度不匹配,报错
print(torch.bmm(x,y).size())

output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)

torch.matmul()

torch.matmul(inputotherout=None) → Tensor

 torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。

参数:

input,other:两个要进行操作的tensor结构

output:结果

一些规则约定:

(1)若两个都是1D(向量)的,则返回两个向量的点积

import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D

x = torch.rand(2,4)
y = torch.rand(4,3) ###维度也要对应才可以乘
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())

output:
tensor([[0.9128, 0.8425, 0.7269],
        [1.4441, 1.5334, 1.3273]]) 
 torch.Size([2, 3])

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系。

import torch
x = torch.rand(4) #1D
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())

### 扩充x =>(,4) 
### 相乘x(,4) * y(4,3) =>(,3) 
### 去掉1D =>(3)

output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430]) 
 torch.Size([3])

(4)若input是2D,other是1D,则返回两者的点积结果。(个人觉得这块也可以理解成给other添加了维度,然后再去掉此维度,只不过维度是(3, )而不是规则(3)中的( ,4)了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)

import torch
x = torch.rand(3) #1D
y = torch.rand(4,3) #2D
print(torch.matmul(y,x),'\n',torch.matmul(y,x).size()) #2D*1D
 
output:
torch.Size([3])
torch.Size([4, 3])
tensor([0.8278, 0.5970, 1.0370, 0.2681]) 
 torch.Size([4])

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)。

        (a)若input是1D,other是大于2D的,则类似于规则(3)。

import torch
x = torch.randn(2, 3, 4)
y = torch.randn(3)
print(torch.matmul(y, x),'\n',torch.matmul(y, x).size()) #1D*3D

output:
tensor([[-0.9747, -0.6660, -1.1704, -1.0522],
        [ 0.0901, -1.5353,  1.5601, -0.0252]]) 
 torch.Size([2, 4])

        (b)若other是1D,input是大于2D的,则类似于规则(4)。

import torch
x = torch.randn(2, 3, 4)
y = torch.randn(4)

print(torch.matmul(x, y),'\n',torch.matmul(x, y).size()) # 3D*1D

output:
tensor([[ 0.6217, -0.1259, -0.2377],
        [ 0.6874,  0.0733,  0.1793]]) 
 torch.Size([2, 3])

            (c)若input和other都是3D的,则与torch.bmm()函数功能一样。

import torch
x = torch.randn(2,2,4)
y = torch.randn(2,4,5)

print(torch.matmul(x, y).size(),'\n',torch.bmm(x, y).size())
print(torch.equal(torch.matmul(x,y),torch.bmm(x,y)))

output:
torch.Size([2, 2, 5]) 
 torch.Size([2, 2, 5])
True

          (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)*  other (k,m,p)  = output(j,k,n,p)。

import torch
x = torch.randn(10,1,2,4)
y = torch.randn(2,4,5)

print(torch.matmul(x, y).size())

output:
torch.Size([10, 2, 2, 5])

这个例子中,可以理解为x中dim=1这个维度可以扩充(广播),y中可以添加一个维度,然后在进行批乘操作。

本文章已经生成可运行项目
### TensorFlow `tf.matmul` PyTorch `torch.bmm` 的区别用法 #### 基本定义 TensorFlow 中的 `tf.matmul` 是用于矩阵乘法的操作,支持两个二维张量或多维张量之间的矩阵相乘。而 PyTorch 中的 `torch.bmm` 则专门针对三维张量设计,表示批量矩阵乘法 (Batch Matrix Multiplication)[^1]。 #### 输入维度的要求 对于 `tf.matmul` 来说,如果输入的是多于两维的张量,则会将最后两维视为矩阵并执行逐批矩阵乘法操作。其余前置维度会被视作批次大小的一部分处理[^2]。 而在 PyTorch 中,`torch.bmm` 明确要求输入必须是三个维度 `(batch_size, n, m)` `(batch_size, m, p)` 形式的张量,并返回形状为 `(batch_size, n, p)` 的结果[^3]。 #### 性能优化方面 由于 `torch.bmm` 更加专注于特定场景下的高效实现——即当数据集包含多个独立的小型矩阵时可以更有效地利用硬件资源完成计算任务;相比之下,虽然 `tf.matmul` 同样能够胜任此类工作负载但它可能不会像前者那样特别针对某些情况做出额外性能上的考量。 以下是两者的一个简单对比代码示例: ```python import tensorflow as tf import torch # Example using TensorFlow's matmul a_tf = tf.random.uniform((5, 3, 4)) # Batch size of 5, matrices are 3x4 b_tf = tf.random.uniform((5, 4, 6)) # Batch size of 5, matrices are 4x6 result_tf = tf.matmul(a_tf, b_tf) # Result will have shape (5, 3, 6) print("Result from TF:", result_tf.shape) # Equivalent operation in PyTorch with bmm a_pt = torch.rand(5, 3, 4) # Same dimensions but now tensors b_pt = torch.rand(5, 4, 6) result_pt = torch.bmm(a_pt, b_pt) # Also results in a tensor shaped (5, 3, 6) print("Result from PT:", result_pt.size()) ``` 通过以上例子可以看出,在相同逻辑下两种框架提供了相似功能但各有侧重之处。 #### 使用注意事项 需要注意的一点是在实际应用过程中要确保所使用的库版本兼容以及正确设置设备环境(CPU/GPU),因为这可能会显著影响最终程序运行效率甚至可能导致错误发生。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Foneone

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值