import math

import torch
import torch.nn as nn

class Mask_multi_head_self_attention(nn.Module):
    def __init__(self, n_heads, d_model):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_model = d_model
        # project the input into the Q, K, V spaces
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        # final projection that combines the heads
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, seq_len, dim = x.shape
        head_dim = self.d_model // self.n_heads
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        # split into heads: (b, n_heads, seq_len, head_dim)
        q = q.view(b, seq_len, self.n_heads, head_dim).permute(0, 2, 1, 3)
        k = k.view(b, seq_len, self.n_heads, head_dim).permute(0, 2, 1, 3)
        v = v.view(b, seq_len, self.n_heads, head_dim).permute(0, 2, 1, 3)
        # scaled dot-product scores: (b, n_heads, seq_len, seq_len)
        score = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)
        # causal mask: lower-triangular matrix, future positions set to -inf
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        score = score.masked_fill(mask == 0, float('-inf'))
        out = self.softmax(score) @ v
        # merge the heads back: (b, seq_len, d_model)
        out = out.permute(0, 2, 1, 3).contiguous().view(b, seq_len, -1)
        return self.w_combine(out)

d_model = 512
n_head = 8
x = torch.rand(5, 100, 512)  # (batch, seq_len, d_model)
model = Mask_multi_head_self_attention(n_head, d_model)
out = model(x)
print(out.shape)  # torch.Size([5, 100, 512])
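
# Optional sanity check (a minimal sketch, not part of the original module):
# with a causal mask, the output at position t should not depend on tokens
# after t, so perturbing only the last token must leave all earlier outputs
# unchanged.
with torch.no_grad():
    x2 = x.clone()
    x2[:, -1, :] = torch.rand_like(x2[:, -1, :])  # change only the last token
    out2 = model(x2)
    # every position except the last should be identical
    print(torch.allclose(out[:, :-1, :], out2[:, :-1, :], atol=1e-6))  # True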