直接参考官方文档,x1和x2是两个输入,A是参数矩阵,如下表达式
但仔细看实现发现这个表达式并不是简单连乘的关系。假设x1(shape是b,n)和x2(shape是b,m)是二维,那么A是个三维tensor(shape是a,n,m)。具体实现时,A先拆成a个(n,m)形状的tensor,x1分别与之矩阵乘后再点乘(公式里的两次乘法),得到了a个(b,m)形状的tensor,然后在axis=1的维度上求和,最终输出成(b,a)的shape,用爱因斯坦表示法就是bn,anm,bm->ba,下面上一下code:
import torch
import torch.nn as nn
import numpy as np
l = torch.ones(2,5)
A = torch.ones(3,5,4)
r = torch.ones(2,4)
print(torch.einsum('bn,anm,bm->ba', l, A, r))
print(torch.nn.functional.bilinear(l,r,A))
x = torch.ones(2,5)
w = torch.ones(3,5)
print(torch.einsum('ij,kj->ik', x,w))
print(torch.nn.functional.linear(x,w))
print('learn nn.Bilinear')
m = nn.Bilinear(5, 4, 3)
output = m(l, r)
print(output.size())
arr_output = output.data.cpu().numpy()
weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = l.

本文详细解读了PyTorch中Bilinear函数的实现原理,通过实例展示了如何利用官方文档中的公式将二维输入与三维参数矩阵进行复杂的矩阵乘法操作,最终得到(b,a)形状的输出。通过代码演示和数值验证,突出了矩阵乘法与求和在实际任务中的应用。
最低0.47元/天 解锁文章
3982

被折叠的 条评论
为什么被折叠?



