In fact, this is the same as ordinary attention: it simply computes a weighted combination over the different frames, so here the frame axis plays the role of L (the sequence length) in the original attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

class cross_frame_attn(nn.Module):
    def __init__(self, embed_dim, n_heads, k_size, v_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_head_dim = self.embed_dim // self.n_heads
        # input projections: queries/keys from k_size features, values from v_size features
        self.to_q = nn.Linear(k_size, self.embed_dim, bias=False)
        self.to_k = nn.Linear(k_size, self.embed_dim, bias=False)
        self.to_v = nn.Linear(v_size, self.embed_dim, bias=False)
        # per-head projections inside the embedding space
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.to_out = nn.Linear(self.embed_dim, v_size, bias=False)

    def forward(self, x, context=None):
        # x: (batch, n_frames, k_size); the frame axis plays the role of L (sequence length).
        # The original forward body was truncated; `context` (value source of shape
        # (batch, n_frames, v_size)) is an assumption and defaults to x (requires k_size == v_size).
        if context is None:
            context = x
        B, T, _ = x.shape
        q = self.q_proj(self.to_q(x)).view(B, T, self.n_heads, self.n_head_dim).transpose(1, 2)
        k = self.k_proj(self.to_k(x)).view(B, T, self.n_heads, self.n_head_dim).transpose(1, 2)
        v = self.v_proj(self.to_v(context)).view(B, T, self.n_heads, self.n_head_dim).transpose(1, 2)
        # attention weights over frames: (B, n_heads, T, T)
        attn = F.softmax(q @ k.transpose(-2, -1) / self.n_head_dim ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, self.embed_dim)
        return self.to_out(out)
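A minimal shape check of the module above; the sizes here (embed_dim=512, n_heads=8, k_size=v_size=320, batch 2, 16 frames) are arbitrary choices for illustration, not values from the original text.

# quick usage sketch: attend over the frame axis
module = cross_frame_attn(embed_dim=512, n_heads=8, k_size=320, v_size=320)
x = torch.randn(2, 16, 320)   # (batch, n_frames, features)
out = module(x)               # each frame is re-expressed as a weighted mix of all frames
print(out.shape)              # torch.Size([2, 16, 320])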