GitHub_Trending/gr/grok Core Components Explored: How the Transformer Model Is Implemented
[Free download link] grok, project address: https://gitcode.com/GitHub_Trending/gr/grok
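1. Transformer Model Architecture Overview
The Transformer model in the GitHub_Trending/gr/grok project is implemented in grok/transformer.py. It uses a decoder-only architecture built from the following core components:
- Embedding layer (Embedding): converts the input token sequence into vector representations
- Position encoding (Position Encoding): injects information about each token's position in the sequence
- Decoder (Decoder): a stack of DecoderBlock layers
- Multi-head attention (Multi-Head Attention): runs several attention heads over the same input and concatenates their outputs
- Feed-forward network (FFN): applies a non-linear transformation to the attention output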
Core code structure of the Transformer class:
class Transformer(nn.Module):
    def __init__(
        self,
        n_layers: int = 4,            # number of decoder layers
        n_heads: int = 4,             # number of attention heads
        d_model: int = 256,           # model dimension
        dropout: float = 0.1,         # dropout rate
        max_context_len: int = 1024,  # maximum context length
        vocab_len: int = 2000,        # vocabulary size
        non_linearity: str = "relu",  # non-linear activation function
        weight_noise: float = 0.0,    # scale of the Gaussian weight noise
    ) -> None:
        super().__init__()
        self.embedding = Embedding(vocab_len, d_model, weight_noise=weight_noise)
        self.register_buffer("position_encoding", self._position_encoding(max_context_len, d_model))
        self.decoder = Decoder(d_model, n_heads, n_layers, dropout, non_linearity, weight_noise=weight_noise)
        self.linear = Linear(d_model, vocab_len, bias=False, weight_noise=weight_noise)
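2. Core Components in Detail
2.1 Enhanced Base Layers
The project extends PyTorch's built-in layers with a weight-noise option intended to improve generalization: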
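Linear layer with noise (grok/transformer.py#L17-L35):
# Imports needed by this excerpt (assumed; not shown in the original snippet)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        # weight_noise is a required keyword argument; remove it before calling nn.Linear
        self.weight_noise = kwargs.pop("weight_noise")
        super().__init__(*args, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.weight_noise > 0 and self.training:
            # Training mode: perturb weight (and bias, if present) with Gaussian noise
            bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise
            weight = self.weight + torch.randn_like(self.weight) * self.weight_noise
        else:
            bias = self.bias
            weight = self.weight
        return F.linear(input, weight, bias)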
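The LayerNorm (grok/transformer.py#L37-L56) and Embedding (grok/transformer.py#L59-L78) layers are extended in a similar way. During training, these layers add Gaussian noise to their parameters to make the model more robust.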
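To make the train/eval behaviour concrete, here is a minimal sketch of a noisy linear layer. The class name NoisyLinear, the noise level 0.1, and the tensor sizes are illustrative assumptions rather than values taken from the project. It checks that two training-mode passes over the same input differ, while evaluation-mode passes are deterministic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Linear):
    """Hypothetical stand-in for the article's Linear layer."""

    def __init__(self, *args, weight_noise: float = 0.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_noise = weight_noise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight, bias = self.weight, self.bias
        if self.weight_noise > 0 and self.training:
            # Fresh Gaussian noise is drawn on every forward pass.
            weight = weight + torch.randn_like(weight) * self.weight_noise
            if bias is not None:
                bias = bias + torch.randn_like(bias) * self.weight_noise
        return F.linear(x, weight, bias)


torch.manual_seed(0)
layer = NoisyLinear(8, 4, bias=False, weight_noise=0.1)
x = torch.randn(2, 8)

layer.train()
train_a, train_b = layer(x), layer(x)  # two passes, two noise samples

layer.eval()
eval_a, eval_b = layer(x), layer(x)    # noise disabled, deterministic

train_outputs_differ = not torch.equal(train_a, train_b)
eval_outputs_match = torch.equal(eval_a, eval_b)
print(train_outputs_differ, eval_outputs_match)
```

Because the stored parameters are never modified in place, the noise acts only as a per-step perturbation and leaves the learned weights intact.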
2.2 Multi-Head Attention
Multi-head attention is the central mechanism of the Transformer. The project's implementation is at grok/transformer.py#L81-L174.
Attention head implementation:
# Imports needed by this excerpt (assumed; not shown in the original snippet)
from math import sqrt
from typing import Tuple, Union

class AttentionHead(nn.Module):
    def __init__(self, d_model: int, d_key: int, weight_noise: float) -> None:
        super().__init__()
        self.d_key = d_key
        self.Wq = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
        self.Wk = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
        self.Wv = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries: Tensor, keys: Tensor, values: Tensor, mask: Union[Tensor, None] = None) -> Tuple[Tensor, ...]:
        queries = self.Wq(queries)
        keys = self.Wk(keys)
        values = self.Wv(values)
        # Compute scaled dot-product attention scores
        attn = torch.matmul(queries, torch.transpose(keys, -2, -1)) / sqrt(self.d_key)
        # Apply the mask (prevents leaking information from future positions)
        if mask is not None:
            attn.masked_fill_(mask == 0, float("-inf"))
        attn = self.softmax(attn)
        result = torch.matmul(attn, values)
        return result, attn, values
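The effect of the mask on the attention weights can be checked on random data. The following is a minimal sketch with no learned projections; the sizes are arbitrary assumptions, and it uses the out-of-place masked_fill rather than the in-place masked_fill_ from the excerpt so that the raw scores are left untouched.

```python
import math
import torch

torch.manual_seed(0)
context_len, d_key = 4, 8
queries = torch.randn(context_len, d_key)
keys = torch.randn(context_len, d_key)
values = torch.randn(context_len, d_key)

# Lower-triangular mask: position t may attend to positions 0..t only.
mask = torch.ones(context_len, context_len).tril()

scores = queries @ keys.transpose(-2, -1) / math.sqrt(d_key)
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = torch.softmax(scores, dim=-1)  # exp(-inf) = 0, so masked entries vanish
result = attn @ values

print(attn)
```

Each row of attn sums to 1, every entry above the diagonal is zero, and the first position can only attend to itself, so its output equals its own value vector.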
Combining the heads:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, heads: int, weight_noise: float = 0.0) -> None:
        super().__init__()
        d_key = int(d_model / heads)  # per-head dimension
        self.attn_heads = nn.ModuleList([
            AttentionHead(d_model, d_key, weight_noise=weight_noise)
            for _ in range(heads)
        ])
        self.Wo = Linear(d_model, d_model, bias=False, weight_noise=weight_noise)

    def forward(self, queries: Tensor, keys: Tensor, values: Tensor, mask: Tensor = None) -> Tuple[Tensor, ...]:
        head_outputs = [h(queries, keys, values, mask) for h in self.attn_heads]
        head_results = [output[0] for output in head_outputs]
        # Concatenate the per-head results along the feature dimension
        multihead_result = torch.cat(head_results, dim=-1)
        # Return the projected output plus each head's attention weights and values
        return self.Wo(multihead_result), [o[1] for o in head_outputs], [o[2] for o in head_outputs]
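One consequence of computing the per-head width as int(d_model / heads) is that d_model should be divisible by heads. Otherwise the concatenated head outputs are narrower than d_model and no longer match the input width that Wo expects. The sketch below illustrates this with placeholder tensors; the helper names head_dim and concat_width are hypothetical.

```python
import torch


def head_dim(d_model: int, heads: int) -> int:
    """Per-head width, computed the same way as in the excerpt."""
    return int(d_model / heads)


def concat_width(d_model: int, heads: int) -> int:
    """Width of the tensor produced by concatenating all head outputs."""
    # Placeholder head outputs of shape (batch=1, context_len=3, d_key).
    head_outputs = [torch.zeros(1, 3, head_dim(d_model, heads)) for _ in range(heads)]
    return torch.cat(head_outputs, dim=-1).shape[-1]


ok_width = concat_width(256, 4)  # 4 heads of width 64
bad_width = concat_width(10, 4)  # 4 heads of width 2
print(ok_width, bad_width)
```

With d_model=256 and 4 heads the concatenation restores the full model width, while with d_model=10 and 4 heads it yields a width of 8.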
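2.3 Decoder Module
Each decoder block consists of multi-head attention followed by a feed-forward network, with a residual connection and layer normalization applied after each of the two sublayers: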
DecoderBlock implementation (grok/transformer.py#L201-L236):
class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float, non_linearity: str = "relu", weight_noise: float = 0.0) -> None:
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, heads, weight_noise=weight_noise)
        self.self_attn_norm = LayerNorm(d_model, weight_noise=weight_noise)
        self.ffn = FFN(d_model, non_linearity=non_linearity, weight_noise=weight_noise)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = LayerNorm(d_model, weight_noise=weight_noise)

    def forward(self, x: Tensor, self_attn_mask: Tensor = None) -> Tuple[Tensor, ...]:
        # Self-attention sublayer
        a1, layer_attns, layer_values = self.self_attn(x, x, x, self_attn_mask)
        a1 = self.self_attn_norm(x + a1)  # residual connection + layer normalization
        # Feed-forward sublayer
        a2 = self.ffn(a1)
        a2 = self.ffn_drop(a2)
        a2 = self.ffn_norm(a1 + a2)  # residual connection + layer normalization
        return a2, layer_attns, layer_values
Feed-forward network implementation:
class FFN(nn.Module):
    def __init__(self, d_model: int, multiplier: int = 4, non_linearity: str = "relu", weight_noise: float = 0.0) -> None:
        super().__init__()
        d_ff = int(multiplier * d_model)  # hidden width of the feed-forward layer
        non_linearities = {"relu": nn.ReLU, "gelu": nn.GELU}
        self.ffn = nn.Sequential(
            Linear(d_model, d_ff, bias=False, weight_noise=weight_noise),
            non_linearities[non_linearity](),
            Linear(d_ff, d_model, bias=False, weight_noise=weight_noise),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.ffn(x)
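Since both linear layers are created with bias=False, the feed-forward network holds 2 · multiplier · d_model² parameters. The sketch below verifies this with plain nn.Linear layers standing in for the project's noisy Linear (weight noise does not change the parameter count); the function name ffn_param_count is hypothetical.

```python
import torch.nn as nn


def ffn_param_count(d_model: int, multiplier: int = 4) -> int:
    """Count the parameters of a bias-free two-layer feed-forward network."""
    d_ff = multiplier * d_model
    ffn = nn.Sequential(
        nn.Linear(d_model, d_ff, bias=False),
        nn.ReLU(),  # activations have no parameters
        nn.Linear(d_ff, d_model, bias=False),
    )
    return sum(p.numel() for p in ffn.parameters())


count = ffn_param_count(256)
print(count)
```

For d_model=256 and the default multiplier of 4 this gives 524,288 parameters per feed-forward network.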
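3. Position Encoding and Masking
3.1 Sinusoidal Position Encoding
The project implements the classic sine/cosine position encoding to give the model information about sequence positions: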
@classmethod
def _position_encoding(cls, context_len: int, d_model: int) -> Tensor:
    # Even feature indices use sin, odd indices use cos at the same frequency
    rows = [
        tensor([
            sin(pos / (10000 ** (i / d_model))) if i % 2 == 0
            else cos(pos / (10000 ** ((i - 1) / d_model)))
            for i in range(d_model)
        ])
        for pos in range(context_len)
    ]
    stack = torch.stack(rows, dim=1)  # shape: (d_model, context_len)
    return stack.T  # shape: (context_len, d_model)
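The nested list comprehension is easy to read but runs a Python loop over every (position, feature) pair. As an illustration only, the sketch below compares a loop-based reference written after the excerpt with a vectorized variant; both function names are hypothetical, and the vectorized form is an alternative formulation rather than the project's code.

```python
import math
import torch


def position_encoding_loop(context_len: int, d_model: int) -> torch.Tensor:
    """Reference version, following the structure of the excerpt."""
    rows = [
        torch.tensor([
            math.sin(pos / (10000 ** (i / d_model))) if i % 2 == 0
            else math.cos(pos / (10000 ** ((i - 1) / d_model)))
            for i in range(d_model)
        ])
        for pos in range(context_len)
    ]
    return torch.stack(rows, dim=1).T  # (context_len, d_model)


def position_encoding_vectorized(context_len: int, d_model: int) -> torch.Tensor:
    """Same values computed with broadcasting instead of Python loops."""
    pos = torch.arange(context_len, dtype=torch.float64).unsqueeze(1)  # (L, 1)
    i = torch.arange(d_model, dtype=torch.float64)                     # (D,)
    even_i = i - (i % 2)  # 0, 0, 2, 2, ...: odd columns reuse the even exponent
    angles = pos / (10000.0 ** (even_i / d_model))                     # (L, D)
    pe = torch.where(i % 2 == 0, torch.sin(angles), torch.cos(angles))
    return pe.to(torch.float32)


pe_loop = position_encoding_loop(16, 8)
pe_vec = position_encoding_vectorized(16, 8)
print(pe_loop.shape, torch.allclose(pe_loop, pe_vec, atol=1e-6))
```

At position 0 every sine column is 0 and every cosine column is 1, which is a quick sanity check for either version.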
3.2 Self-Attention Mask
To prevent the decoder from seeing future tokens, a lower-triangular mask is used:
@staticmethod
def make_mask(context_len: int) -> Tensor:
    return torch.ones([context_len, context_len]).tril()  # lower-triangular matrix
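A simple way to confirm that the mask blocks future information is to perturb the last position and check that the outputs at all earlier positions stay the same. The sketch below does this with a single attention head that has no learned projections; the function causal_self_attention and the tensor sizes are illustrative assumptions.

```python
import math
import torch


def make_mask(context_len: int) -> torch.Tensor:
    return torch.ones([context_len, context_len]).tril()


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention without learned projections (illustrative)."""
    context_len, d_model = x.shape
    scores = x @ x.T / math.sqrt(d_model)
    scores = scores.masked_fill(make_mask(context_len) == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x


torch.manual_seed(0)
x = torch.randn(5, 8)
x_changed = x.clone()
x_changed[-1] += 1.0  # perturb only the last position

out = causal_self_attention(x)
out_changed = causal_self_attention(x_changed)

earlier_unchanged = torch.allclose(out[:-1], out_changed[:-1])
last_changed = not torch.allclose(out[-1], out_changed[-1])
print(earlier_unchanged, last_changed)
```

Only the final position's output reacts to the change, which is the property a decoder needs when it predicts each token from its prefix.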
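4. Model Training Flow
The training entry point is scripts/train.py; the core training logic is implemented in the TrainableTransformer class in grok/training.py.
4.1 Training Configuration
Model hyperparameters are configured through command-line arguments. The key parameters are: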
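- --n_layers: number of decoder layers (default 2)
- --n_heads: number of attention heads (default 4)
- --d_model: model dimension (default 128)
- --batchsize: batch size (default 0, meaning it is computed automatically)
- --max_lr: maximum learning rate (default 1e-3)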
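As a rough illustration of how such flags behave, the sketch below declares them with Python's standard argparse module. The flag names and defaults come from the list above; the argument types, help strings, and the function name build_parser are assumptions, and the project's actual parser may be organized differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed above."""
    parser = argparse.ArgumentParser(description="Illustrative hyperparameter parser")
    parser.add_argument("--n_layers", type=int, default=2, help="number of decoder layers")
    parser.add_argument("--n_heads", type=int, default=4, help="number of attention heads")
    parser.add_argument("--d_model", type=int, default=128, help="model dimension")
    # The int type is an assumption; 0 is treated as "compute automatically".
    parser.add_argument("--batchsize", type=int, default=0, help="batch size; 0 means automatic")
    parser.add_argument("--max_lr", type=float, default=1e-3, help="maximum learning rate")
    return parser


# Parse explicit argument lists instead of sys.argv so the sketch runs anywhere.
defaults = build_parser().parse_args([])
custom = build_parser().parse_args(["--n_layers", "4", "--d_model", "256"])
print(defaults.n_layers, defaults.d_model, custom.n_layers, custom.d_model)
```

Flags that are not passed keep their defaults, so overriding --n_layers and --d_model leaves --n_heads at 4.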
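4.2 Forward Pass
def forward(self, x: Tensor, save_activations: bool = False) -> Tuple[Tensor, ...]:
    # Input processing and position encoding
    x = self.embed(x)
    # Decoder forward pass (self_attn_mask is prepared in code omitted from this excerpt)
    decoded, attentions, values = self.decoder(x, self_attn_mask)
    # Output layer: project onto the vocabulary space
    y_hat = self.linear(decoded)
    return y_hat, attentions, values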
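The excerpt does not show what self.embed does. A plausible reading, stated here as an assumption, is that it looks up the token embeddings and adds the first context_len rows of the position_encoding buffer. The sketch below illustrates that idea with a hypothetical TinyEmbedder module and a random placeholder table in place of the sinusoidal values.

```python
import torch
import torch.nn as nn


class TinyEmbedder(nn.Module):
    """Hypothetical module; it only illustrates the assumed embed step."""

    def __init__(self, vocab_len: int, d_model: int, max_context_len: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_len, d_model)
        # Placeholder table; the real model stores sinusoidal values here.
        self.register_buffer("position_encoding", torch.randn(max_context_len, d_model))

    def embed(self, indices: torch.Tensor) -> torch.Tensor:
        context_len = indices.shape[-1]
        pe = self.position_encoding[:context_len, :]  # (context_len, d_model)
        return self.embedding(indices) + pe           # broadcasts over the batch


torch.manual_seed(0)
model = TinyEmbedder(vocab_len=20, d_model=8, max_context_len=16)
tokens = torch.randint(0, 20, (3, 5))  # (batch, context_len)
embedded = model.embed(tokens)
print(embedded.shape)
```

Slicing the buffer to the current sequence length lets one table, sized for max_context_len, serve inputs of any shorter length.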
4.3 Loss Computation
The cross-entropy loss is computed only on the right-hand side of the equation (the answer part):
# Compute the loss only on the right-hand side of the equation
eq_token_index = self.train_dataset.tokenizer.stoi["="]
# The position of "=" is read from the first sequence and applied to the whole batch
eq_position = int(torch.nonzero(y[0, :] == eq_token_index).squeeze())
y_rhs = y[..., eq_position + 1 :]
y_hat_rhs = y_hat[..., eq_position + 1 :]
loss = F.cross_entropy(y_hat_rhs, y_rhs, reduction=reduction)
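The slicing can be tried on toy data. The sketch below assumes the logits are laid out as (batch, vocab, context_len), which is the layout F.cross_entropy accepts for targets of shape (batch, context_len); the token ids, the id 3 for "=", and the tensor sizes are made up for illustration. Note that the approach relies on "=" sitting at the same position in every sequence of the batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab_len, context_len = 2, 10, 6
eq_token_index = 3  # hypothetical id of the "=" token

# Targets: every row has "=" at position 3, followed by two answer tokens.
y = torch.tensor([[1, 5, 2, 3, 7, 8],
                  [4, 5, 6, 3, 9, 0]])
# Random logits in the assumed (batch, vocab, context_len) layout.
y_hat = torch.randn(batch, vocab_len, context_len)

eq_position = int(torch.nonzero(y[0, :] == eq_token_index).squeeze())
y_rhs = y[..., eq_position + 1:]          # (batch, 2)
y_hat_rhs = y_hat[..., eq_position + 1:]  # (batch, vocab, 2)

loss = F.cross_entropy(y_hat_rhs, y_rhs, reduction="mean")

# Cross-check against a manual computation over the answer positions only.
log_probs = F.log_softmax(y_hat_rhs, dim=1)
manual = -log_probs.gather(1, y_rhs.unsqueeze(1)).mean()
print(eq_position, float(loss), float(manual))
```

Tokens to the left of and including "=" contribute nothing to the loss, so the model is trained only on how well it predicts the answer.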
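5. Model Application and Extensions
5.1 Weight Noise
A distinguishing feature of the project is the weight-noise mechanism built into each layer: Gaussian noise is added to the weights during training to improve generalization and robustness. The noise level is controlled by the --weight_noise parameter.
5.2 Model Visualization
The scripts/visualize_metrics.py script can be used to visualize training metrics and attention weights, which helps in analyzing the model's behavior and decision process.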
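6. Summary
The Transformer implementation in the GitHub_Trending/gr/grok project has the following characteristics:
- It uses a decoder-only architecture, which suits sequence-generation tasks
- It implements enhanced base layers with weight noise to improve generalization
- It fully implements the core mechanisms, including multi-head attention and position encoding
- It tailors the loss computation to the task of solving mathematical equations
Through the modular design of core files such as grok/transformer.py and grok/training.py, the project provides a clear, extensible Transformer implementation that works well as a reference for learning about and researching the Transformer architecture.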
Authoring statement: parts of this article were generated with AI assistance (AIGC) and are for reference only.



