1. GPT2Block Overall Structure
1.1 Block setup
import torch
from torch import nn
from transformers import GPT2Model, GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

cfg = GPT2Config()                            # defaults: n_embd=768, n_head=12, n_positions=1024
print(cfg.add_cross_attention)                # False -> decoder-only block, no cross-attention
blk = GPT2Block(cfg, layer_idx=0)             # a single Transformer decoder layer
hidden_states = torch.randn(10, 1024, 768)    # dummy input: [batch, seq_len, n_embd]
1.2 Block architecture
The classic pre-norm Transformer decoder architecture, as shown by print(blk); a quick check of the Conv1D projection shapes follows the module tree.
GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
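Conv1D here is GPT-2's linear layer with its weight stored as (in_features, out_features). A small shape-check sketch, reusing the blk from 1.1, shows how the projections are sized:
# c_attn maps 768 -> 3*768 so q, k, v come from one matmul; the MLP expands 768 -> 4*768
print(blk.attn.c_attn.weight.shape)   # torch.Size([768, 2304])
print(blk.attn.c_proj.weight.shape)   # torch.Size([768, 768])
print(blk.mlp.c_fc.weight.shape)      # torch.Size([768, 3072])
print(blk.mlp.c_proj.weight.shape)    # torch.Size([3072, 768])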
1.3 forward-preNorm
y = attn(ln_1(x)) + x
out = mlp(ln_2(y)) + y
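A minimal sketch of this pre-norm data flow, reusing the blk and hidden_states from 1.1 (the [0] index is needed because GPT2Attention returns a tuple; the names x, y, out are just for illustration):
# pre-norm: normalize first, then the sub-layer, then the residual add
x = hidden_states
y = blk.attn(blk.ln_1(x))[0] + x   # self-attention sub-layer + residual
out = blk.mlp(blk.ln_2(y)) + y     # MLP sub-layer + residual
print(out.shape)                   # torch.Size([10, 1024, 768])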
2. GPT2Attention
- Split the hidden states into q, k, v via the single c_attn projection:
gpt2_att, split_size = blk.attn, blk.attn.split_size   # split_size = n_embd = 768
query, key, value = gpt2_att.c_attn(hidden_states).split(split_size, dim=2)
- Reshape q, k, v into multiple heads:
query = gpt2_att._split_heads(query, gpt2_att.num_heads, gpt2_att.head_dim)
key = gpt2_att._split_heads(key, gpt2_att.num_heads, gpt2_att.head_dim)
value = gpt2_att._split_heads(value, gpt2_att.num_heads, gpt2_att.head_dim)
print(f'{query.shape=}')  # [batch, n_head, seq_len, head_dim]
- Compute attention: scaled dot-product of query and key (divided by sqrt(head_dim)), masked causally, passed through softmax, then used to take a weighted sum over value; see the sketch below.
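A hand-written sketch of these remaining steps, following the classic eager GPT2Attention implementation and reusing query/key/value and gpt2_att from above (head_dim, attn_dropout, resid_dropout and c_proj are the module's own attributes):
import math

# scaled dot-product scores with a lower-triangular (causal) mask
scores = query @ key.transpose(-1, -2) / math.sqrt(gpt2_att.head_dim)
causal = torch.tril(torch.ones(scores.shape[-2:], dtype=torch.bool))
scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
attn_weights = gpt2_att.attn_dropout(torch.softmax(scores, dim=-1))

# weighted sum over value, merge heads, then output projection + residual dropout
attn_output = attn_weights @ value                       # [batch, n_head, seq_len, head_dim]
attn_output = attn_output.transpose(1, 2).contiguous()   # [batch, seq_len, n_head, head_dim]
attn_output = attn_output.flatten(2)                     # [batch, seq_len, n_embd]
attn_output = gpt2_att.resid_dropout(gpt2_att.c_proj(attn_output))
print(attn_output.shape)                                 # torch.Size([10, 1024, 768])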