# %%
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Dict
# %% [markdown]
# 1. RoPE: Rotary Position Embedding (implemented from scratch)
# %%
def precompute_freqs_cis(hidden_size: int, max_seq_len: int, base: int = 10000, num_attention_heads: int = 8) -> torch.Tensor:
"""
预计算复数形式的旋转位置编码 (RoPE)
Args:
hidden_size: 模型维度(如 512)
max_seq_len: 最大序列长度(如 8192)
base: 频率基数,默认 10000
Returns:
freqs_cis: complex tensor of shape (max_seq_len, head_dim // 2)
"""
head_dim = hidden_size // num_attention_heads
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len) # 位置索引 [0, 1, ..., S-1]
freqs = torch.einsum("s,d->sd", t, inv_freq) # (S, D//2)
# 将每个频率复制两次,扩展为完整 head_dim 维度
# 因为我们要对每两个相邻维度做旋转:[x0,x1] -> x0*cosθ - x1*sinθ ...
freqs = torch.cat([freqs, freqs], dim=-1) # now (S, D)
# 转为复数角度:cosθ + i*sinθ
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis # shape: (max_seq_len, head_dim)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
在最后一个维度上,将后半部分移到前面并取负。
例如:x = [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
这对应于乘以 i 的操作(复数旋转90度)
"""
x1, x2 = x.chunk(2, dim=-1) # 分成前后两半
return torch.cat((-x2, x1), dim=-1) # 后半取负放前
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
将 RoPE 应用于 query 和 key 张量。
Args:
q: (bsz, n_heads, seq_len, head_dim)
k: (bsz, n_heads, seq_len, head_dim)
freqs_cis: (seq_len, head_dim // 2)
Returns:
q_embed, k_embed
"""
# 提取实部和虚部,并扩展维度以便广播
cos = freqs_cis.real.view(1, 1, freqs_cis.size(0), -1) # (1,1,S,D)
sin = freqs_cis.imag.view(1, 1, freqs_cis.size(0), -1)
# 使用标准公式: q * cos + rotate_half(q) * sin
q_out = (q * cos[:, :, :q.size(2), :]) + (_rotate_half(q) * sin[:, :, :q.size(2), :])
k_out = (k * cos[:, :, :k.size(2), :]) + (_rotate_half(k) * sin[:, :, :k.size(2), :])
return q_out.type_as(q), k_out.type_as(k)
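# %% [markdown]
# Quick sanity check (not part of the model itself): RoPE is a pure rotation, so it should
# leave per-vector norms unchanged, and the precomputed table should have shape
# (max_seq_len, head_dim). The sizes below (hidden 512, 8 heads, head_dim 64) are
# illustrative choices matching the config used later.
# %%
freqs_demo = precompute_freqs_cis(hidden_size=512, max_seq_len=32, num_attention_heads=8)
print(freqs_demo.shape)  # torch.Size([32, 64]) -> (max_seq_len, head_dim)

q_demo = torch.randn(2, 8, 32, 64)
k_demo = torch.randn(2, 8, 32, 64)
q_rot, k_rot = apply_rotary_pos_emb(q_demo, k_demo, freqs_demo)
print(q_rot.shape)  # unchanged: (2, 8, 32, 64)
# Rotations preserve length: each head vector's norm should be numerically unchanged
print(torch.allclose(q_demo.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True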
# %% [markdown]
# 2. Custom attention layer
# %%
class YiziAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # For debugging / visualization: store the most recent attention weights
self.attn_weights = None
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        # Project to Q, K, V and split into heads
        query_states = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freqs_cis)
        # Scaled dot-product attention with a causal mask.
        # Without the mask each position could attend to future tokens, which makes
        # next-token training trivial and inconsistent with autoregressive generation.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        attn_weights = attn_weights + causal_mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        # Store the attention map (useful for visualization)
        self.attn_weights = attn_weights.detach().cpu()
        # Weighted sum over values, then merge heads
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.o_proj(attn_output)
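# %% [markdown]
# Smoke test for the attention layer (illustrative only): random input, check the output
# shape and that each attention row is still a probability distribution (sums to 1) with
# the causal mask applied.
# %%
attn_demo = YiziAttention(hidden_size=512, num_heads=8)
x_demo = torch.randn(2, 16, 512)
freqs_demo = precompute_freqs_cis(hidden_size=512, max_seq_len=16, num_attention_heads=8)
out_demo = attn_demo(x_demo, freqs_demo)
print(out_demo.shape)                 # (2, 16, 512)
print(attn_demo.attn_weights.shape)   # (2, 8, 16, 16)
print(torch.allclose(attn_demo.attn_weights.sum(dim=-1), torch.ones(2, 8, 16), atol=1e-5))  # True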
# %% [markdown]
# 3. Feed-forward network + residual connections
# %%
class YiziBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
super().__init__()
self.attn = YiziAttention(hidden_size, num_heads)
self.norm1 = nn.LayerNorm(hidden_size)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.GELU(),
nn.Linear(intermediate_size, hidden_size)
)
self.norm2 = nn.LayerNorm(hidden_size)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer with residual connection (pre-norm)
        x = x + self.attn(self.norm1(x), freqs_cis)
        # MLP sub-layer with residual connection
        x = x + self.mlp(self.norm2(x))
return x
# %% [markdown]
# 4. Main model: YiziLM
# %%
class YiziLMConfig:
def __init__(
self,
vocab_size: int = 30000,
hidden_size: int = 512,
num_hidden_layers: int = 6,
num_attention_heads: int = 8,
max_position_embeddings: int = 8192,
intermediate_size: int = 2048
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.intermediate_size = intermediate_size
class YiziLM(nn.Module):
def __init__(self, config: YiziLMConfig):
super().__init__()
self.config = config
        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_tokens.weight.data.normal_(mean=0.0, std=0.02)  # standard initialization (std=0.02)
        # Precompute RoPE for all positions and register it as a buffer (saved with the model)
freqs_cis = precompute_freqs_cis(
hidden_size=config.hidden_size,
max_seq_len=config.max_position_embeddings,
base=10000,
num_attention_heads=config.num_attention_heads
)
        self.register_buffer("freqs_cis", freqs_cis)
        # Stacked Transformer blocks
self.layers = nn.ModuleList([
YiziBlock(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size
) for _ in range(config.num_hidden_layers)
])
        # Final normalization and decoding head
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying (optional): the token embedding and lm_head share parameters,
        # which also keeps their initialization consistent.
        self.lm_head.weight = self.embed_tokens.weight
        # Note: for finer control, gradients could be detached or rescaled during training.
def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
inputs_embeds = self.embed_tokens(input_ids)
bsz, seq_len, _ = inputs_embeds.shape
        # Slice the precomputed RoPE table for the current sequence length
        freqs_cis = self.freqs_cis[:seq_len]  # (seq_len, head_dim)
hidden_states = inputs_embeds
for layer in self.layers:
hidden_states = layer(hidden_states, freqs_cis)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
output = {"logits": logits}
if labels is not None:
loss = nn.CrossEntropyLoss(ignore_index=-100)(
logits.view(-1, logits.size(-1)),
labels.view(-1))
output["loss"] = loss
return output
@torch.no_grad()
def generate(
self,
input_ids: torch.LongTensor,
max_new_tokens: int = 128,
temperature: float = 0.7,
top_k: int = 50,
eos_token_id: Optional[int] = None
) -> torch.LongTensor:
"""
自回归生成文本
"""
for _ in range(max_new_tokens):
logits = self.forward(input_ids)["logits"]
next_token_logits = logits[:, -1, :] / temperature
# Top-k filtering
if top_k > 0:
v, _ = torch.topk(next_token_logits, top_k)
pivot = v[:, [-1]]
next_token_logits = torch.where(
next_token_logits < pivot,
torch.full_like(next_token_logits, -float('inf')),
next_token_logits
)
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=1)
if eos_token_id is not None and next_token.item() == eos_token_id:
break
return input_ids
# %% [markdown]
# 5. Model initialization
# %%
# Create the configuration
config = YiziLMConfig(
vocab_size=30000,
hidden_size=512,
num_hidden_layers=6,
num_attention_heads=8,
max_position_embeddings=8192,
intermediate_size=2048
)
# Build the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = YiziLM(config).to(device)
# Count total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")
# %% [markdown]
# 6. Data generation
# %%
def generate_arithmetic_batch(batch_size=32, seq_len=16, start_range=(300, 1000)):
    # Clamp the start values so the sequence never goes below 0
    # (token ids must be valid, non-negative embedding indices)
    low = max(start_range[0], (seq_len - 1) * 10)
    starts = torch.randint(low, start_range[1], (batch_size,))
    sequences = []
    for s in starts.tolist():
        seq = list(range(s, s - seq_len * 10, -10))  # step of -10, seq_len terms
        sequences.append(seq)
    input_ids = torch.tensor(sequences).to(device)
    # labels: predict the next token (inputs shifted left by one)
    labels = input_ids.clone()
    labels = torch.roll(labels, shifts=-1, dims=1)
    labels[:, -1] = -100  # ignore the loss at the last position
    return input_ids, labels
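# %% [markdown]
# A quick look at one training example (values are illustrative, drawn at random): the
# inputs step down by 10, and the labels are the inputs shifted left by one position with
# the final position set to -100 so it is ignored by the loss.
# %%
ids_demo, labels_demo = generate_arithmetic_batch(batch_size=2, seq_len=8)
print(ids_demo[0].tolist())     # e.g. [730, 720, 710, 700, 690, 680, 670, 660]
print(labels_demo[0].tolist())  # e.g. [720, 710, 700, 690, 680, 670, 660, -100]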
# %% [markdown]
# 7. Validation function
# %%
@torch.no_grad()
def validate_and_test(model, device, test_range=(350, 400)):
"""
在训练时未覆盖的数字范围测试模型表现
目的:区分「记忆」vs「归纳」
"""
model.eval()
correct = 0
total = 0
for x in range(*test_range):
inp = torch.tensor([[x]], device=device)
logits = model(inp)["logits"]
pred_next = logits[0, -1].argmax().item()
if pred_next == x - 10:
correct += 1
total += 1
accuracy = correct / total
print(f"🔍 Validation Accuracy on unseen {test_range}: {accuracy:.2%} ({correct}/{total})")
return accuracy > 0.95
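# %% [markdown]
# Baseline before training (illustrative): an untrained model should score at or near 0%
# here, which makes the post-training accuracy easier to interpret.
# %%
validate_and_test(model, device)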
# %% [markdown]
# 8. Training
# %%
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)  # moderate initial learning rate
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-5)  # cosine LR decay
print("🚀 Starting training...")
for step in range(100):
    input_ids, labels = generate_arithmetic_batch(batch_size=32, seq_len=50)
    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs["loss"]
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # update the learning rate
    if step % 50 == 0:
        print(f"Step {step} - Loss: {loss.item():.4f}, Grad Norm: {grad_norm:.4f}")
        # Check generalization every 50 steps
        if validate_and_test(model, device):
            print("🎉 The model has learned the '-10' rule! Stopping training.")
            break
        model.train()  # validate_and_test switches to eval mode; switch back before continuing
# %% [markdown]
# 9. Testing
# %%
def test_one_step(model, current_value, context_len=20):
    model.eval()
    # Build an arithmetic context: current_value, current_value-10, current_value-20, ...
    # Clamp the length so that no term (and no expected answer) goes below 0,
    # since token ids must stay inside the vocabulary.
    context_len = min(context_len, current_value // 10)
    context = list(range(current_value, current_value - context_len * 10, -10))
    input_ids = torch.tensor([context]).to(next(model.parameters()).device)
    with torch.no_grad():
        logits = model(input_ids=input_ids)["logits"]  # forward() returns a dict
    pred = logits[0, -1].argmax().item()  # prediction at the last position
    expected = context[-1] - 10  # the next term of the sequence
    print(f"✅ {current_value} -> {pred} | Expected: {expected} {'✔️' if pred == expected else '❌'}")
    return pred

print("\n🧪 Final test:")
for val in [100, 150, 300, 400]:
    test_one_step(model, val)
# Visualize the attention map of the first layer (optional)
attn_weights = model.layers[0].attn.attn_weights
if attn_weights is not None:
import matplotlib.pyplot as plt
plt.figure(figsize=(6,6))
plt.imshow(attn_weights[0].mean(0), cmap='viridis')  # average over all heads
plt.title("Average Attention Map (Layer 1)")
plt.xlabel("Key Position"); plt.ylabel("Query Position")
plt.colorbar()
plt.show()
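# %% [markdown]
# Optional: exercise the generate() method defined above. With top_k=1 the sampling is
# effectively greedy, so a trained model should continue the "-10" pattern from the prompt.
# The prompt values below are arbitrary examples.
# %%
prompt = torch.tensor([[800, 790, 780]], device=device)
generated = model.generate(prompt, max_new_tokens=5, temperature=1.0, top_k=1)
print(generated[0].tolist())  # expected to continue roughly as 770, 760, 750, ...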
# %% [markdown]
# 10. Saving the model
# %%
torch.save(model.state_dict(), "yizilm_70m.pth")
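# %% [markdown]
# Reloading later (a minimal sketch): rebuild the model from the same config and load the
# saved weights. The freqs_cis buffer is part of the state_dict, so it is restored as well.
# %%
reloaded = YiziLM(config)
reloaded.load_state_dict(torch.load("yizilm_70m.pth", map_location="cpu"))
reloaded.eval()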