In short, TransformerEncoderLayer and the decoder layer used by decoder-only LLMs share the same skeleton: self MHA followed by an FFN (with the differences listed below, the mask being first among them). A TransformerDecoder additionally has a cross MHA on top of a TransformerEncoder, and its masking also changes considerably.
Differences between TransformerEncoderLayer and Qwen2DecoderLayer
Feature | TransformerEncoderLayer | Qwen2DecoderLayer |
---|---|---|
Mask | Bidirectional (padding mask only) | Causal mask (see the sketch below) |
Normalization | LayerNorm | RMSNorm (Qwen2RMSNorm) |
Norm/residual placement | Post-Norm in the original Transformer | Pre-Norm: input_layernorm before self-attention and post_attention_layernorm before the MLP, with the residual added after each sublayer |
Activation | ReLU | SiLU |
Positional encoding | sin/cos absolute | RoPE (extended to longer contexts with YaRN) |
MLP | up-project to 4x dim, then down-project | Qwen2MLP adds a gate_proj branch that multiplies the up-projection (gated/SwiGLU-style) |
… |
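To make the mask row concrete, here is a minimal sketch (not taken from either library; the token values and lengths are made-up illustrations): an encoder's self-attention usually only needs a padding mask over PAD keys, while a decoder-only layer additionally imposes the causal constraint.

import torch

seq = torch.tensor([[5, 7, 9, 0]])  # one sequence of length 4, token id 0 = PAD (illustrative)
# padding mask: True = masked key position; attention is otherwise bidirectional
pad_mask = seq.eq(0).unsqueeze(1)  # [1, 1, 4]
# causal mask: position i may only attend to positions <= i
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # [4, 4]
print(pad_mask)
print(causal_mask)
print(pad_mask | causal_mask)  # decoder-style self-attention mask: causal + padding, broadcast to [1, 4, 4]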
Here is the key code:
- TransformerEncoderLayer: a simplified version of the torch implementation in /opt/miniconda3/envs/torch20/lib/python3.8/site-packages/torch/nn/modules/transformer.py
- Qwen2DecoderLayer: from the HF implementation in /opt/miniconda3/envs/torch20/lib/python3.8/site-packages/transformers/models/qwen2/modeling_qwen2.py
Simplified TransformerEncoderLayer
The residual path applies Dropout before the add + LayerNorm; a single layer applies Dropout three times (dropout1, dropout2, plus the one inside the FFN).
import torch
import torch.nn as nn
import math
class MultiheadAttn(nn.Module):
    def __init__(self, dim, nheads):
        super(MultiheadAttn, self).__init__()
        self.dim = dim
        self.nheads = nheads
        self.head_dim = dim // nheads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, query, key, value, attn_mask=None):
        # this demo assumes self-attention, i.e. query/key/value have the same length
        bs, qlen, dim = query.shape
        # project, then split into heads: [bs, nheads, qlen, head_dim]
        q = self.q_proj(query).reshape(bs, qlen, self.nheads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).reshape(bs, qlen, self.nheads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).reshape(bs, qlen, self.nheads, self.head_dim).transpose(1, 2)
        # scaled dot-product attention scores: [bs, nheads, qlen, qlen]
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == True, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v)
        # merge heads and project back: [bs, qlen, dim]
        output = self.o_proj(output.transpose(1, 2).reshape(bs, qlen, dim))
        return output, attn
class PoswiseFeedforwardNet(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # classic FFN: up-project to 4*dim, ReLU, down-project back to dim;
        # the residual add + LayerNorm is applied in MyTransformerEncoderLayer, not here
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(4 * dim, dim)
        )

    def forward(self, input):
        return self.ffn(input)
class MyTransformerEncoderLayer(nn.Module):
    def __init__(self, dim, nheads):
        super().__init__()
        self.msa = MultiheadAttn(nheads=nheads, dim=dim)
        self.ffn = PoswiseFeedforwardNet(dim=dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None):
        # Post-Norm: x = LN(x + Dropout(sublayer(x)))
        attn_output, x_attn = self.msa(x, x, x, attn_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        return x
class MyTransformerEncoder(nn.Module):
    def __init__(self, dim, nheads, nlayers):
        super().__init__()
        self.layers = nn.ModuleList([MyTransformerEncoderLayer(dim, nheads) for _ in range(nlayers)])

    def forward(self, x, attn_mask=None):
        for layer in self.layers:
            x = layer(x, attn_mask)
        return x
if __name__ == '__main__':
    embed_dim, num_heads = 256, 8
    q_len, bs = 2, 3
    multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    query = torch.ones(bs, q_len, embed_dim)
    key = torch.ones(bs, q_len, embed_dim)
    value = torch.ones(bs, q_len, embed_dim)
    # causal mask, True = not allowed to attend (an all-True mask would make every softmax row NaN)
    attn_mask = torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1)
    attn_mask = attn_mask.unsqueeze(0).expand(bs * num_heads, q_len, q_len)
    attn_output, attn_output_weights = multihead_attn(query, key, value, attn_mask=attn_mask)
    print('attn_output={}'.format(attn_output.shape))
    print('attn_output_weights={}'.format(attn_output_weights.shape))
    print('--------------')
    my_multihead_attn = MultiheadAttn(embed_dim, num_heads)
    attn_mask = attn_mask.reshape(bs, num_heads, q_len, q_len)
    my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value, attn_mask=attn_mask)
    print('my_attn_output={}'.format(my_attn_output.shape))
    print('my_attn_output_weights={}'.format(my_attn_output_weights.shape))
    my_transformer_encoder = MyTransformerEncoder(dim=embed_dim, nheads=num_heads, nlayers=3)
    print('my_transformer_encoder_output.shape={}'.format(my_transformer_encoder(query).shape))
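As a quick sanity check, the simplified layer can be run side by side with torch's own nn.TransformerEncoderLayer. This is only a shape comparison (the two modules have different randomly initialized weights and a different default dim_feedforward, so the values will not match); it assumes the MyTransformerEncoderLayer class defined above is in scope.

import torch
import torch.nn as nn

embed_dim, num_heads, bs, q_len = 256, 8, 3, 2
torch_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
my_layer = MyTransformerEncoderLayer(dim=embed_dim, nheads=num_heads)
x = torch.randn(bs, q_len, embed_dim)
print(torch_layer(x).shape)  # torch.Size([3, 2, 256])
print(my_layer(x).shape)     # torch.Size([3, 2, 256])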
Qwen2DecoderLayer
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
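The layer above delegates to Qwen2MLP and Qwen2RMSNorm, which are not shown. Below is a minimal self-contained sketch of what the two modules compute, written against plain hidden_size / intermediate_size integers instead of a Qwen2Config (the HF versions take the config object and resolve the activation from config.hidden_act, which is "silu" for Qwen2); GatedMLPSketch and RMSNormSketch are hypothetical names for this illustration.

import torch
import torch.nn as nn

class GatedMLPSketch(nn.Module):
    # Qwen2MLP-style gated FFN: down_proj(silu(gate_proj(x)) * up_proj(x)), no biases
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class RMSNormSketch(nn.Module):
    # Qwen2RMSNorm-style: rescale by the root mean square computed in float32; no mean subtraction, no bias
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

x = torch.randn(2, 5, 64)
print(GatedMLPSketch(64, 128)(x).shape)  # torch.Size([2, 5, 64])
print(RMSNormSketch(64)(x).shape)        # torch.Size([2, 5, 64])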
Differences between TransformerEncoderLayer and TransformerDecoderLayer
Here is a demo; pay attention to how the masks are built:
import torch
import torch.nn as nn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, dim):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim)
        )
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        return self.ln(x + self.fc(x))
# A single decoder layer
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        # MultiHeadAttention is assumed to be defined elsewhere (similar to MultiheadAttn above);
        # d_model is a global hyper-parameter in this demo
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet(d_model)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        # dec_inputs: [batch_size, tgt_len, d_model]
        # enc_outputs: [batch_size, src_len, d_model]
        # dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        # dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # cross attention: Q comes from the decoder, K and V come from the encoder outputs
        # dec_outputs: [batch_size, tgt_len, d_model]
        # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs = self.pos_ffn(dec_outputs)
        # dec_outputs: [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn
# The full decoder stack
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        # tgt_vocab_size, d_model, n_layers are global hyper-parameters in this demo;
        # PositionalEncoding is assumed to be defined elsewhere
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        # dec_inputs: [batch_size, tgt_len], enc_inputs: [batch_size, src_len], enc_outputs: [batch_size, src_len, d_model]
        dec_outputs = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs)  # [batch_size, tgt_len, d_model]
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_seq_mask = get_attn_seq_mask(dec_inputs)  # [batch_size, tgt_len, tgt_len]
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_seq_mask), 0)  # [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)  # queries come from the decoder, so the rows are tgt_len: [batch_size, tgt_len, src_len]
        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model]
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns
def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k may not be equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # token id 0 is the PAD token; True marks key positions that must be masked out
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]
def get_attn_seq_mask(seq):
    '''
    seq: [batch_size, tgt_len]. For example, batch_size=3, tgt_len=4 returns:
    tensor([[[0., 1., 1., 1.],
             [0., 0., 1., 1.],
             [0., 0., 0., 1.],
             [0., 0., 0., 0.]],
            [[0., 1., 1., 1.],
             [0., 0., 1., 1.],
             [0., 0., 0., 1.],
             [0., 0., 0., 0.]],
            [[0., 1., 1., 1.],
             [0., 0., 1., 1.],
             [0., 0., 0., 1.],
             [0., 0., 0., 0.]]])
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # the commented-out lines below are the numpy version
    # subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # upper triangular matrix
    # subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = torch.triu(torch.ones(attn_shape), diagonal=1)
    return subsequence_mask
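A tiny usage example for the two mask helpers (the token IDs below are made up; 0 is PAD), combining the padding and subsequent masks exactly as Decoder.forward does:

if __name__ == '__main__':
    dec_inputs = torch.tensor([[4, 8, 15, 0]])            # [batch_size=1, tgt_len=4], 0 is PAD
    pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [1, 4, 4], True at PAD key positions
    seq_mask = get_attn_seq_mask(dec_inputs)              # [1, 4, 4], upper-triangular causal part
    self_attn_mask = torch.gt(pad_mask + seq_mask, 0)     # True = not allowed to attend
    print(self_attn_mask)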