GQA and MSA with einops

GQA

Grouped-query attention keeps all nheads query heads but projects keys and values down to only ngroups heads, so each KV head is shared by heads_per_group = nheads // ngroups query heads. MyGQA below does the head reshuffling with einops.rearrange; MyGQA2 is the equivalent reshape/transpose version, and the test script at the bottom checks that the two produce identical outputs.

import torch
import torch.nn as nn
import math
from einops import rearrange

class MyGQA(nn.Module):
    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups

        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)
        self.ln = nn.LayerNorm(dim)  # note: defined but never used in forward

    def forward(self, query, key, value, attn_mask = None):
        bs,q_len,dim = query.shape

        # q = self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim).transpose(1,2).reshape(bs, self.nheads, q_len, self.head_dim)
        # k = self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2,3).reshape(bs, self.nheads, q_len, self.head_dim)
        # v = self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2,3).reshape(bs, self.nheads, q_len, self.head_dim)
        q = rearrange(self.q_proj(query), 'b l (head k) -> b head l k', head=self.nheads)
        k = rearrange(self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0), '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k', heads_per_group=self.heads_per_group, ngroups=self.ngroups)
        v = rearrange(self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0), '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k', heads_per_group=self.heads_per_group, ngroups=self.ngroups)
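        # q: (bs, nheads, q_len, head_dim)
        # k, v: (bs, nheads, q_len, head_dim); each KV group is repeated heads_per_group times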

        attn = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v) # bs,nheads,q_len,head_dim
        output = self.o_proj(rearrange(output, 'b head l k -> b l (head k)'))
        # output = self.o_proj(output.transpose(1,2).reshape(bs, q_len, self.nheads*self.head_dim))
        return output, attn

class MyGQA2(nn.Module):
    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups

        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, query, key, value, attn_mask = None):
        bs,q_len,dim = query.shape
        q = self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim).transpose(1,2).reshape(bs, self.nheads, q_len, self.head_dim)
        k = self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2,3).reshape(bs, self.nheads, q_len, self.head_dim)
        v = self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2,3).reshape(bs, self.nheads, q_len, self.head_dim)

        attn = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v) # bs,nheads,q_len,head_dim
        output = self.o_proj(output.transpose(1,2).reshape(bs, q_len, self.nheads*self.head_dim))
        return output, attn

if __name__ == '__main__':

    embed_dim,num_heads,num_groups=256,8,4
    q_len,bs = 2,3

    query = torch.randn(bs, q_len, embed_dim)
    key = torch.randn(bs, q_len, embed_dim)
    value = torch.randn(bs, q_len, embed_dim)

    my_multihead_attn = MyGQA(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn.parameters():
        param.data.fill_(0.1)
    my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
    print('my_attn_output={}'.format(my_attn_output))

    my_multihead_attn2 = MyGQA2(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn2.parameters():
        param.data.fill_(0.1)
    my_attn_output2, my_attn_output_weights2 = my_multihead_attn2(query, key, value)
    print('my_attn_output2={}'.format(my_attn_output2))

    max_diff = torch.max(torch.abs(my_attn_output - my_attn_output2)).item()

    print(torch.equal(my_attn_output_weights, my_attn_output_weights2))
    print('max_diff={}'.format(max_diff))
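
A note on the KV expansion in MyGQA: the repeat_interleave on the batch dimension followed by a rearrange can also be written as a single einops.repeat call. A minimal sketch checking the equivalence; the shapes and axis names below are illustrative and not tied to the classes above:

import torch
from einops import rearrange, repeat

bs, q_len, ngroups, heads_per_group, head_dim = 2, 3, 4, 2, 5
kv = torch.randn(bs, q_len, ngroups * head_dim)  # what k_proj / v_proj would produce

# two-step expansion as used in MyGQA
k1 = rearrange(
    kv.repeat_interleave(heads_per_group, dim=0),
    '(b hpg) l (g d) -> b (hpg g) l d',
    hpg=heads_per_group, g=ngroups,
)

# single-call alternative: repeat introduces the hpg axis directly
k2 = repeat(kv, 'b l (g d) -> b (hpg g) l d', hpg=heads_per_group, g=ngroups)

print(torch.equal(k1, k2))  # True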

MSA

Standard multi-head self-attention for comparison: MyMultiheadAttention uses plain reshape/transpose, MyMHA uses einops.rearrange, and both are checked against nn.MultiheadAttention with every parameter filled with 0.1.


import torch.nn as nn
import torch
from torch import Tensor
import math
from einops import rearrange

class MyMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MyMultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim,embed_dim)
        self.k_proj = nn.Linear(embed_dim,embed_dim)
        self.v_proj = nn.Linear(embed_dim,embed_dim)
        self.fc = nn.Linear(embed_dim,embed_dim)

    def scaled_dot_product_attention(self, q:Tensor, k:Tensor, v:Tensor, attn_mask = None):
        bs, q_len, head_dim = q.shape
        q = q / math.sqrt(head_dim)
        # (bs*num_heads, q_len, head_dim) x (bs*num_heads, head_dim, k_len) -> (bs*num_heads, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        # (bs*num_heads, q_len, k_len) x (bs*num_heads, k_len, head_dim) -> (bs*num_heads, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output,attn

    def forward(self, query:Tensor, key:Tensor, value:Tensor, attn_mask=None):
        # assert query, key, value have the same shape
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads

        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        k = self.k_proj(key).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        v = self.v_proj(value).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
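        # q, k, v: (bs * num_heads, q_len, head_dim)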
        self_output,attn = self.scaled_dot_product_attention(q, k, v, attn_mask)

        # self_output: bs * num_heads, q_len, head_dim
        # attn: bs * num_heads, q_len, k_len
        output = self.fc(self_output.reshape(bs, self.num_heads, q_len, head_dim).transpose(1,2).reshape(bs, q_len, self.num_heads*head_dim))
        # the Hugging Face implementation puts this final fc inside BertSelfOutput instead
        return output,attn
    
class MyMHA(nn.Module):
    def __init__(self, nheads, dim):
        super().__init__()
        self.dim = dim
        self.nheads = nheads
        self.head_dim = dim // nheads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
    
    def forward(self, query, key, value, attn_mask = None):
        # query.shape: bs, seq_len, dim
        q = rearrange(self.q_proj(query), 'b l (nheads dim) -> b nheads l dim', nheads = self.nheads)
        k = rearrange(self.k_proj(key), 'b l (nheads dim) -> b nheads dim l', nheads = self.nheads)
        v = rearrange(self.v_proj(value), 'b l (nheads dim) -> b nheads l dim', nheads = self.nheads)
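        # q, v: (bs, nheads, l, head_dim); k is pre-transposed to (bs, nheads, head_dim, l)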

        attn = torch.matmul(q, k) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = self.o_proj(rearrange(torch.matmul(attn, v), 'b nheads l dim -> b l (nheads dim)'))

        return output, attn

embed_dim,num_heads=256,8
q_len,bs = 2,3

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
query = torch.randn(bs, q_len, embed_dim)
key = torch.randn(bs, q_len, embed_dim)
value = torch.randn(bs, q_len, embed_dim)
attn_mask = torch.ones(bs*num_heads, q_len, q_len)  # all-ones float mask is a no-op: nn.MultiheadAttention adds it as a constant, and it contains no zeros for the masked_fill path
for param in multihead_attn.parameters():
    param.data.fill_(0.1)

output, attn = multihead_attn(query, key, value, attn_mask=attn_mask)
print('output={}'.format(output.shape))
print('attn={}'.format(attn.shape))
print('--------------')
my_multihead_attn = MyMultiheadAttention(embed_dim, num_heads)
for param in my_multihead_attn.parameters():
    param.data.fill_(0.1)
output1, attn1 = my_multihead_attn(query, key, value, attn_mask=attn_mask)
print('output1={}'.format(output1.shape))
print('attn1={}'.format(attn1.shape))

print('--------------')
my_multihead_attn2 = MyMHA(num_heads,embed_dim)
for param in my_multihead_attn2.parameters():
    param.data.fill_(0.1)

output2, attn2 = my_multihead_attn2(query, key, value, attn_mask=attn_mask.reshape(bs, num_heads, q_len, q_len))
print('output2={}'.format(output2.shape))
print('attn2={}'.format(attn2.shape))
print('largest_diff2={}'.format(torch.max(torch.abs(output1-output2))))
print('largest_diff1={}'.format(torch.max(torch.abs(output1-output))))
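
As a cross-check on the core attention math (independent of the projection layers), PyTorch 2.x provides torch.nn.functional.scaled_dot_product_attention, which fuses the softmax(QK^T / sqrt(d)) V computation written out manually above. A minimal sketch, assuming PyTorch >= 2.0; shapes are illustrative:

import math
import torch
import torch.nn.functional as F

b, h, l, d = 3, 8, 2, 32
q, k, v = (torch.randn(b, h, l, d) for _ in range(3))

# manual computation, same form as in MyMHA.forward
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
manual = torch.matmul(scores.softmax(dim=-1), v)

# fused built-in (PyTorch >= 2.0)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.max(torch.abs(manual - fused)).item())  # expected to be ~0 up to float error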

References

  • https://einops.rocks/pytorch-examples.html