代码实现如下:
import torch
import torch.nn.functional as F
# Toy pairwise "attention" between the rows of A and the rows of B.
# For every pair (i, j) the feature vector [A[i], B[j], A[i] * B[j]] is
# scored with a single linear map W; the scores are folded into a
# (len(A), len(B)) matrix that is then softmax-normalised along rows and
# along columns.
A = torch.ones(3, 2)
B = torch.ones(2, 2)
# Scale row k of A by a[k] and row k of B by b[k] (broadcast over columns).
a = torch.tensor([1.0, 2.0, 3.0]).reshape(3, 1)
b = torch.tensor([4.0, 5.0]).reshape(2, 1)
A = torch.mul(A, a)
B = torch.mul(B, b)

# Feature width shared by A and B. Derived from the data instead of being
# hard-coded as 2 (in the reshapes) and 6 (= 3 * 2, the size of W), so the
# demo keeps working when the width changes.
feat_dim = A.shape[1]
if B.shape[1] != feat_dim:
    raise ValueError(
        f"A and B must have the same feature width, got {feat_dim} and {B.shape[1]}"
    )

# Scoring weights for the concatenated [A[i], B[j], A[i] * B[j]] feature.
# NOTE: unseeded, so the printed numbers differ from run to run.
W = torch.rand(3 * feat_dim, 1)
print('W:', W)
print(A)
print(B)
length_A = A.shape[0]
length_B = B.shape[0]

# Tile all of B length_A times: rows of B_2 are B[0], B[1], ..., B[0], B[1], ...
B_1 = B.unsqueeze(0).expand(length_A, -1, -1)
print('B_1:', B_1)
B_2 = B_1.reshape(-1, feat_dim)
print("B_2:", B_2)
# Repeat each row of A length_B times in a row: A[0], A[0], ..., A[1], A[1], ...
A_1 = A.unsqueeze(1).expand(-1, length_B, -1)
print("A_1:", A_1)
A_2 = A_1.reshape(-1, feat_dim)
print("A_2:", A_2)
# Row i * length_B + j of A_2 / B_2 now pairs A[i] with B[j];
# C_col_first is their element-wise product.
C_col_first = torch.mul(A_2, B_2)
print(C_col_first)

# Pair features [A[i], B[j], A[i] * B[j]], one row per (i, j) pair.
conA_B = torch.cat([A_2, B_2], 1)
print('conA_B:', conA_B)
final = torch.cat([conA_B, C_col_first], 1)
print(final)
# Score every pair, then fold back so that att[i, j] = score(A[i], B[j]).
att_ori = torch.mm(final, W)
att = att_ori.reshape(length_A, length_B)
print(att)

# softmax over dim=1 normalises each row ("行归一化" = row-normalised);
# softmax over dim=0 normalises each column ("列归一化" = column-normalised).
att_col_softmax = torch.softmax(att, dim=0)
att_row_softmax = torch.softmax(att, dim=1)
print("行归一化:\n", att_row_softmax)
print("列归一化:\n", att_col_softmax)
# Recomputed with torch.nn.functional.softmax for comparison with the
# torch.softmax results above.
att_col_softmax = F.softmax(att, dim=0)
att_row_softmax = F.softmax(att, dim=1)
print("行归一化:\n", att_row_softmax)
print("列归一化:\n", att_col_softmax)
执行结果（W 由 torch.rand 随机生成且未设置随机种子，每次运行的数值会不同）:
W: tensor([[0.4134],
[0.6084],
[0.4033],
[0.0100],
[0.1279],
[0.0321]])
tensor([[1., 1.],
[2., 2.],
[3., 3.]])
tensor([[4., 4.],
[5., 5.]])
B_1: tensor([[[4., 4.],
[5., 5.]],
[[4., 4.],
[5., 5.]],
[[4., 4.],
[5., 5.]]])
B_2: tensor([[4., 4.],
[5., 5.],
[4., 4.],
[5., 5.],
[4., 4.],
[5., 5.]])
A_1: tensor([[[1., 1.],
[1., 1.]],
[[2., 2.],
[2., 2.]],
[[3., 3.],
[3., 3.]]])
A_2: tensor([[1., 1.],
[1., 1.],
[2., 2.],
[2., 2.],
[3., 3.],
[3., 3.]])
tensor([[ 4., 4.],
[ 5., 5.],
[ 8., 8.],
[10., 10.],
[12., 12.],
[15., 15.]])
conA_B: tensor([[1., 1., 4., 4.],
[1., 1., 5., 5.],
[2., 2., 4., 4.],
[2., 2., 5., 5.],
[3., 3., 4., 4.],
[3., 3., 5., 5.]])
tensor([[ 1., 1., 4., 4., 4., 4.],
[ 1., 1., 5., 5., 5., 5.],
[ 2., 2., 4., 4., 8., 8.],
[ 2., 2., 5., 5., 10., 10.],
[ 3., 3., 4., 4., 12., 12.],
[ 3., 3., 5., 5., 15., 15.]])
tensor([[3.3150, 3.8883],
[4.9768, 5.7101],
[6.6386, 7.5320]])
行归一化:
tensor([[0.3605, 0.6395],
[0.3245, 0.6755],
[0.2904, 0.7096]])
列归一化:
tensor([[0.0294, 0.0220],
[0.1548, 0.1361],
[0.8158, 0.8418]])
行归一化:
tensor([[0.3605, 0.6395],
[0.3245, 0.6755],
[0.2904, 0.7096]])
列归一化:
tensor([[0.0294, 0.0220],
[0.1548, 0.1361],
[0.8158, 0.8418]])
Process finished with exit code 0