Implementing dual attention

The idea: given two feature matrices A (3 x 2) and B (2 x 2), enumerate every (row of A, row of B) pair, concatenate each pair with its element-wise product into a 6-dimensional feature, project that feature to a scalar score with a learned weight vector W, and normalize the resulting 3 x 2 score matrix with a softmax along each dimension. The code implementation is as follows:

import torch
import torch.nn.functional as F

# Build two toy feature matrices: A is 3 x 2, B is 2 x 2.
A = torch.ones(3, 2)
B = torch.ones(2, 2)
a = torch.tensor([1., 2., 3.]).reshape(3, 1)
b = torch.tensor([4., 5.]).reshape(2, 1)
A = torch.mul(A, a)  # rows of A become [1,1], [2,2], [3,3]
B = torch.mul(B, b)  # rows of B become [4,4], [5,5]
# Learned projection weight: maps each 6-dim pair feature to a scalar score.
W = torch.rand(6, 1)
print('W:', W)
print(A)
print(B)
length_A = A.shape[0]
length_B = B.shape[0]
# Tile B along a new leading dim so every row of A can be paired with every row of B.
B_1 = B.unsqueeze(0).expand(length_A, -1, -1)  # (3, 2, 2)
print('B_1:', B_1)
B_2 = B_1.reshape(-1, 2)  # (6, 2): B's rows repeated in order 4,5,4,5,4,5
print("B_2:", B_2)
# Repeat each row of A once per row of B.
A_1 = A.unsqueeze(1).expand(-1, length_B, -1)  # (3, 2, 2)
print("A_1:", A_1)
A_2 = A_1.reshape(-1, 2)  # (6, 2): A's rows repeated in order 1,1,2,2,3,3
print("A_2:", A_2)
# Element-wise product of every (A-row, B-row) pair.
C_col_first = torch.mul(A_2, B_2)
print(C_col_first)
# Concatenate each pair with its element-wise product: one 6-dim feature per pair.
conA_B = torch.cat([A_2, B_2], 1)  # (6, 4)
print('conA_B:', conA_B)
final = torch.cat([conA_B, C_col_first], 1)  # (6, 6)
print(final)
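# Row k of `final` holds the pair (A row i, B row j) with k = i * length_B + j,
# so reshaping the 6 scores to (length_A, length_B) below yields the attention matrix.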
# Project each pair feature to a scalar score and reshape to the 3 x 2 attention matrix.
att_ori = torch.mm(final, W)  # (6, 1)
att = att_ori.reshape(length_A, length_B)
print(att)
# Dual attention: normalize the same score matrix along both dimensions.
att_col_softmax = torch.softmax(att, dim=0)  # each column sums to 1
att_row_softmax = torch.softmax(att, dim=1)  # each row sums to 1
print("Row-normalized:\n", att_row_softmax)
print("Column-normalized:\n", att_col_softmax)

# F.softmax is the same operation through the functional API; the results are identical.
att_col_softmax = F.softmax(att, dim=0)
att_row_softmax = F.softmax(att, dim=1)
print("Row-normalized:\n", att_row_softmax)
print("Column-normalized:\n", att_col_softmax)




Execution output:

W: tensor([[0.4134],
        [0.6084],
        [0.4033],
        [0.0100],
        [0.1279],
        [0.0321]])
tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])
tensor([[4., 4.],
        [5., 5.]])
B_1: tensor([[[4., 4.],
         [5., 5.]],

        [[4., 4.],
         [5., 5.]],

        [[4., 4.],
         [5., 5.]]])
B_2: tensor([[4., 4.],
        [5., 5.],
        [4., 4.],
        [5., 5.],
        [4., 4.],
        [5., 5.]])
A_1: tensor([[[1., 1.],
         [1., 1.]],

        [[2., 2.],
         [2., 2.]],

        [[3., 3.],
         [3., 3.]]])
A_2: tensor([[1., 1.],
        [1., 1.],
        [2., 2.],
        [2., 2.],
        [3., 3.],
        [3., 3.]])
tensor([[ 4.,  4.],
        [ 5.,  5.],
        [ 8.,  8.],
        [10., 10.],
        [12., 12.],
        [15., 15.]])
conA_B: tensor([[1., 1., 4., 4.],
        [1., 1., 5., 5.],
        [2., 2., 4., 4.],
        [2., 2., 5., 5.],
        [3., 3., 4., 4.],
        [3., 3., 5., 5.]])
tensor([[ 1.,  1.,  4.,  4.,  4.,  4.],
        [ 1.,  1.,  5.,  5.,  5.,  5.],
        [ 2.,  2.,  4.,  4.,  8.,  8.],
        [ 2.,  2.,  5.,  5., 10., 10.],
        [ 3.,  3.,  4.,  4., 12., 12.],
        [ 3.,  3.,  5.,  5., 15., 15.]])
tensor([[3.3150, 3.8883],
        [4.9768, 5.7101],
        [6.6386, 7.5320]])
Row-normalized:
 tensor([[0.3605, 0.6395],
        [0.3245, 0.6755],
        [0.2904, 0.7096]])
Column-normalized:
 tensor([[0.0294, 0.0220],
        [0.1548, 0.1361],
        [0.8158, 0.8418]])
Row-normalized:
 tensor([[0.3605, 0.6395],
        [0.3245, 0.6755],
        [0.2904, 0.7096]])
Column-normalized:
 tensor([[0.0294, 0.0220],
        [0.1548, 0.1361],
        [0.8158, 0.8418]])

Process finished with exit code 0
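For reuse, the whole computation can be wrapped in one function that builds the pair features with broadcasting instead of explicit expand/reshape. This is a minimal sketch under the same assumptions as the script above (a single weight vector of size 3*d projecting each pair feature to a scalar); the name dual_attention is ours, not from any library:

import torch

def dual_attention(A, B, w):
    # A: (n, d), B: (m, d), w: (3*d, 1) -- assumed shapes, matching the script above
    n, m = A.shape[0], B.shape[0]
    A_pair = A.unsqueeze(1).expand(n, m, -1)  # (n, m, d): row i of A varies along dim 0
    B_pair = B.unsqueeze(0).expand(n, m, -1)  # (n, m, d): row j of B varies along dim 1
    feats = torch.cat([A_pair, B_pair, A_pair * B_pair], dim=-1)  # (n, m, 3*d)
    att = torch.matmul(feats, w).squeeze(-1)  # (n, m) score matrix
    return torch.softmax(att, dim=1), torch.softmax(att, dim=0)

row_norm, col_norm = dual_attention(A, B, W)
# row_norm and col_norm match att_row_softmax and att_col_softmax printed above.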

 
