双线性注意力机制通常用于处理多模态数据,如文本和图像。它通过计算两个模态特征之间的双线性交互,得到一个注意力分布,从而实现对不同模态信息的融合和加权。双线性注意力机制可以更好地捕捉不同模态之间的语义关联,提高模型在多模态任务中的性能。
import torch
import torch.nn as nn
class BilinearAttention(nn.Module):
def __init__(self, v_dim, q_dim, num_hid):
super(BilinearAttention, self).__init__()
self.v_proj = nn.Linear(v_dim, num_hid)
self.q_proj = nn.Linear(q_dim, num_hid)
self.dropout = nn.Dropout(0.5)
self.linear = nn.Linear(num_hid, 1)
def forward(self, v, q):
v_proj = self.v_proj(v)
q_proj = self.q_proj(q).unsqueeze(1)
joint_repr = v_proj * q_proj
joint_repr = self.dropout(joint_repr)
logits = self.linear(joint_repr)
att = nn.functional.softmax(logits, dim=1)
return att