# %%
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Dict, List
import matplotlib.pyplot as plt
import numpy as np
import random
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
# %% [markdown]
# 1. RoPE: Rotary Position Embedding (implemented from scratch)
# %%
def precompute_freqs_cis(
hidden_size: int,
max_seq_len: int,
base: int = 10000,
num_attention_heads: int = 8
) -> torch.Tensor:
"""
预计算复数形式的旋转位置编码 (RoPE)
Args:
hidden_size: 模型维度
max_seq_len: 最大序列长度
base: 频率基数
num_attention_heads: 注意力头数
Returns:
freqs_cis: complex tensor of shape (max_seq_len, head_dim // 2)
"""
head_dim = hidden_size // num_attention_heads
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()  # position indices [0, 1, ..., S-1]
    freqs = torch.einsum("s,d->sd", t, inv_freq)  # (S, D//2)
    # Convert each angle to a unit complex number: cos(theta) + i*sin(theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis  # shape: (max_seq_len, head_dim // 2)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Along the last dimension, move the second half in front of the first half and negate it.
    Example: x = [x0, x1, x2, x3] -> [-x2, -x3, x0, x1]
    This corresponds to multiplying each (x_i, x_{i+d/2}) pair by i (a 90-degree complex rotation).
    """
    x1, x2 = x.chunk(2, dim=-1)  # split into first and second halves
    return torch.cat((-x2, x1), dim=-1)  # negated second half goes first
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE to the query and key tensors.
    Args:
        q: (bsz, n_heads, seq_len, head_dim)
        k: (bsz, n_heads, seq_len, head_dim)
        freqs_cis: complex tensor of shape (seq_len, head_dim // 2)
    Returns:
        q_embed, k_embed
    """
    # Recover cos/sin from the complex table and repeat them across both halves of head_dim
    # so they match the pairing used by _rotate_half, then add leading dims for broadcasting
    # over (batch, heads).
    cos = torch.cat((freqs_cis.real, freqs_cis.real), dim=-1).unsqueeze(0).unsqueeze(0)  # (1, 1, S, D)
    sin = torch.cat((freqs_cis.imag, freqs_cis.imag), dim=-1).unsqueeze(0).unsqueeze(0)  # (1, 1, S, D)
    # Standard formula: x * cos + rotate_half(x) * sin
    q_out = (q * cos[:, :, :q.size(2), :]) + (_rotate_half(q) * sin[:, :, :q.size(2), :])
    k_out = (k * cos[:, :, :k.size(2), :]) + (_rotate_half(k) * sin[:, :, :k.size(2), :])
    return q_out.type_as(q), k_out.type_as(k)
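# %% [markdown]
# A small, illustrative sanity check (not part of the model): it confirms the RoPE helpers
# return the expected shapes and that the rotation leaves per-vector norms unchanged. The
# tensor sizes used here are arbitrary.
# %%
_freqs = precompute_freqs_cis(hidden_size=512, max_seq_len=16, num_attention_heads=8)
print(_freqs.shape)  # (16, 32): head_dim // 2 complex frequencies per position
_q = torch.randn(1, 8, 16, 64)
_k = torch.randn(1, 8, 16, 64)
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _freqs)
print(_q_rot.shape, _k_rot.shape)  # (1, 8, 16, 64) each
# RoPE is a pure rotation of (x_i, x_{i+d/2}) pairs, so vector norms should be preserved
print(torch.allclose(_q.norm(dim=-1), _q_rot.norm(dim=-1), atol=1e-4))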
# %% [markdown]
# 2. Hybrid Embedding Layer (joint handling of word tokens and numeric values)
# %%
class HybridEmbedding(nn.Module):
def __init__(
self,
vocab_size: int,
hidden_size: int,
max_value: int = 10000,
value_offset: int = 30000
):
super().__init__()
self.hidden_size = hidden_size
self.max_value = max_value
self.value_offset = value_offset
        # 1. Standard vocabulary embedding (for language tokens)
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.word_embeddings.weight.data.normal_(mean=0.0, std=0.02)
        # 2. Value-aware embedding (for real numeric values)
        self.value_embedding = nn.Linear(1, hidden_size)
        # More stable initialization
        nn.init.xavier_uniform_(self.value_embedding.weight)
        if self.value_embedding.bias is not None:
            nn.init.zeros_(self.value_embedding.bias)
        # 3. Type flag: 0 = word, 1 = value
        self.type_embedding = nn.Embedding(2, hidden_size)
        # 4. Learned absolute position embeddings (RoPE is applied separately inside attention)
        self.position_embeddings = nn.Embedding(8192, hidden_size)
        self.position_embeddings.weight.data.normal_(mean=0.0, std=0.02)
        self.dropout = nn.Dropout(0.1)
def forward(
self,
input_ids: torch.LongTensor,
value_mask: Optional[torch.BoolTensor] = None,
position_ids: Optional[torch.LongTensor] = None
) -> torch.Tensor:
"""
input_ids: (B, S)
- < 30000: 正常词汇
- >=30000: 表示数值,减去 offset 得到真实数值
value_mask: (B, S), bool, True 表示该位置是数值
"""
batch_size, seq_length = input_ids.shape
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
embeddings = torch.zeros(batch_size, seq_length, self.hidden_size, device=input_ids.device)
        # ==== Step 1: ordinary word tokens ====
        word_mask = ~value_mask if value_mask is not None else (input_ids < self.value_offset)
        if word_mask.any():
            word_embs = self.word_embeddings(input_ids[word_mask])
            embeddings[word_mask] = word_embs
        # ==== Step 2: numeric values ====
        if value_mask is not None and value_mask.any():
            raw_values = input_ids[value_mask].float() - self.value_offset  # decode the real value
            values = raw_values.unsqueeze(-1)  # (N, 1)
            value_embs = torch.tanh(self.value_embedding(values))  # bounded activation for stability
            embeddings[value_mask] = value_embs
        # ==== Step 3: add type and position information ====
type_ids = torch.where(value_mask, 1, 0) if value_mask is not None \
else torch.zeros_like(input_ids)
type_embs = self.type_embedding(type_ids)
pos_embs = self.position_embeddings(position_ids)
embeddings = embeddings + type_embs + pos_embs
embeddings = self.dropout(embeddings)
return embeddings
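# %% [markdown]
# Illustrative check of HybridEmbedding on a tiny mixed sequence (all sizes and ids below are
# arbitrary): ids below 30000 are treated as word tokens, ids at or above 30000 as
# offset-encoded numeric values.
# %%
_emb = HybridEmbedding(vocab_size=30000, hidden_size=64)
_ids = torch.tensor([[100, 30260, 30270, 106]])        # "the", 260, 270, "?"
_vmask = torch.tensor([[False, True, True, False]])
print(_emb(_ids, _vmask).shape)  # (1, 4, 64)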
# %% [markdown]
# 3. Custom attention layer
# %%
class YiziAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # For debugging: store the attention weights from the most recent forward pass
        self.attn_weights = None
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
bsz, seq_len, _ = x.shape
        # Project to Q, K, V
        query_states = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freqs_cis)
        # Scaled dot-product attention with a causal mask (the model is autoregressive,
        # so a position may only attend to itself and earlier positions)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        attn_weights = attn_weights + causal_mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        # Keep the attention map for later visualization
        self.attn_weights = attn_weights.detach()
        # Weighted sum over values, then merge heads
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.o_proj(attn_output)
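# %% [markdown]
# Illustrative shape check for the attention layer (arbitrary sizes): one forward pass with a
# matching RoPE table, then inspect the cached attention weights.
# %%
_attn = YiziAttention(hidden_size=64, num_heads=4)
_x = torch.randn(2, 10, 64)
_fc = precompute_freqs_cis(hidden_size=64, max_seq_len=10, num_attention_heads=4)
print(_attn(_x, _fc).shape)      # (2, 10, 64)
print(_attn.attn_weights.shape)  # (2, 4, 10, 10)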
# %% [markdown]
# 4. Feed-forward network + residual connections
# %%
class YiziBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
super().__init__()
self.attn = YiziAttention(hidden_size, num_heads)
self.norm1 = nn.LayerNorm(hidden_size)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.GELU(),
nn.Linear(intermediate_size, hidden_size)
)
self.norm2 = nn.LayerNorm(hidden_size)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Attention with residual connection (pre-norm)
        x = x + self.attn(self.norm1(x), freqs_cis)
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
return x
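# %% [markdown]
# The same kind of quick check for a full Transformer block (arbitrary sizes): input and
# output shapes must match because of the residual connections.
# %%
_blk = YiziBlock(hidden_size=64, num_heads=4, intermediate_size=128)
_fc = precompute_freqs_cis(hidden_size=64, max_seq_len=10, num_attention_heads=4)
print(_blk(torch.randn(2, 10, 64), _fc).shape)  # (2, 10, 64)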
# %% [markdown]
# 5. Main model: YiziLM
# %%
class YiziLMConfig:
def __init__(
self,
vocab_size: int = 30000,
hidden_size: int = 512,
num_hidden_layers: int = 6,
num_attention_heads: int = 8,
max_position_embeddings: int = 8192,
intermediate_size: int = 2048,
value_offset: int = 30000
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.intermediate_size = intermediate_size
self.value_offset = value_offset
class YiziLM(nn.Module):
def __init__(self, config: YiziLMConfig):
super().__init__()
self.config = config
        # Normalization statistics for the value regression head
        # (dataset values roughly lie in 100~999)
        self.value_mean = 500.0
        self.value_scale = 500.0
        # Hybrid embedding in place of a plain token embedding
self.embed_tokens = HybridEmbedding(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
max_value=10000,
value_offset=config.value_offset
)
        # Precompute the RoPE table for all positions and register it as a buffer
        # (saved together with the model)
freqs_cis = precompute_freqs_cis(
hidden_size=config.hidden_size,
max_seq_len=config.max_position_embeddings,
base=10000,
num_attention_heads=config.num_attention_heads
)
        self.register_buffer("freqs_cis", freqs_cis)
        # Stack of Transformer blocks
self.layers = nn.ModuleList([
YiziBlock(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size
) for _ in range(config.num_hidden_layers)
])
        # Final output normalization
        self.norm = nn.LayerNorm(config.hidden_size)
        # === Dual output heads ===
        # Head 1: language token classification (0 ~ vocab_size-1)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.word_embeddings.weight  # weight tying
        # Head 2: value regression head -- predicts a single scalar (normalized value)
        self.value_head = nn.Linear(config.hidden_size, 1)
        # Initialize all weights, then re-apply a small-gain init to the value head
        # so its initial predictions stay close to zero
        self.apply(self._init_weights)
        nn.init.xavier_uniform_(self.value_head.weight, gain=1e-3)
        nn.init.zeros_(self.value_head.bias)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.LongTensor,
value_mask: Optional[torch.BoolTensor] = None,
labels: Optional[torch.LongTensor] = None
) -> Dict[str, torch.Tensor]:
        # Slice the RoPE table to the current sequence length
        seq_len = input_ids.size(1)
        freqs_cis = self.freqs_cis[:seq_len]
        # Hybrid embedding
        inputs_embeds = self.embed_tokens(input_ids, value_mask)
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, freqs_cis)
        hidden_states = self.norm(hidden_states)  # (B, S, D)
        # === Dual-head predictions ===
        word_logits = self.lm_head(hidden_states)  # (B, S, V)
        value_pred = self.value_head(hidden_states).squeeze(-1)  # (B, S)
        # value_pred is not merged into the logits; it is supervised only at value positions
output = {}
if labels is not None and value_mask is not None:
total_loss = 0.0
n_losses = 0
            # ==== 1. Word-token loss: cross entropy ====
            word_positions = ~(value_mask)  # non-value positions
            if word_positions.any():
                word_labels = labels[word_positions]
                # Filter out ignore_index == -100
                valid_mask = (word_labels != -100)
                if valid_mask.any():
                    pred = word_logits[word_positions][valid_mask]
                    target = word_labels[valid_mask]
                    # Make sure targets lie in [0, vocab_size)
                    target = torch.clamp(target, 0, self.config.vocab_size - 1)
                    loss_word = nn.functional.cross_entropy(pred, target)
                    total_loss += loss_word
                    n_losses += 1
            # ==== 2. Value-position loss: MSE regression ====
            value_positions = value_mask & (labels != -100)  # value positions that are not padding
            if value_positions.any():
                # Recover the real values and normalize them
                raw_values = labels[value_positions].float() - self.config.value_offset
                normalized_values = (raw_values - self.value_mean) / self.value_scale
                pred_normalized = value_pred[value_positions]
                loss_value = nn.functional.mse_loss(pred_normalized, normalized_values)
                total_loss += loss_value
                n_losses += 1
            loss = total_loss / n_losses if n_losses > 0 else total_loss
            output["loss"] = loss
        # Return both heads so the generation code can decide which one to use
        output["word_logits"] = word_logits
        output["value_pred"] = value_pred  # raw regression output (normalized, not discretized)
return output
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        value_mask: Optional[torch.BoolTensor] = None,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 50,
        eos_token_id: Optional[int] = None,
        strategy: str = "hybrid"  # 'hybrid', 'force_value', 'force_word'
    ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
        """
        Autoregressive generation with a hybrid output strategy (assumes batch size 1).
        """
        if value_mask is None:
            value_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for _ in range(max_new_tokens):
            outputs = self.forward(input_ids, value_mask)
            word_logits = outputs["word_logits"][:, -1, :]  # (B, V)
            value_pred_raw = outputs["value_pred"][:, -1]   # (B,) normalized regression output
            next_token_id = None
            is_value = False
            if strategy == "hybrid":
                # Heuristic: if the previous token was a value, keep predicting values
                if value_mask[0, -1]:
                    # De-normalize the regression output before converting it to a token id
                    predicted_value = max(0, round(value_pred_raw.item() * self.value_scale + self.value_mean))
                    next_token_id = predicted_value + self.config.value_offset
                    is_value = True
                else:
                    # Otherwise sample from the language head
                    logit = word_logits / temperature
                    if top_k > 0:
                        v, _ = torch.topk(logit, min(top_k, logit.size(-1)))
                        pivot = v[:, [-1]]
                        logit = torch.where(logit < pivot, -float('inf'), logit)
                    prob = torch.softmax(logit, dim=-1)
                    next_token_id = torch.multinomial(prob, num_samples=1).item()
                    is_value = False
            elif strategy == "force_value":
                # Always emit a value token, de-normalizing the regression output first
                predicted_value = max(0, round(value_pred_raw.item() * self.value_scale + self.value_mean))
                next_token_id = predicted_value + self.config.value_offset
                is_value = True
            else:
                # 'force_word': always sample from the language head
                logit = word_logits / temperature
                if top_k > 0:
                    v, _ = torch.topk(logit, min(top_k, logit.size(-1)))
                    pivot = v[:, [-1]]
                    logit = torch.where(logit < pivot, -float('inf'), logit)
                prob = torch.softmax(logit, dim=-1)
                next_token_id = torch.multinomial(prob, num_samples=1).item()
                is_value = False
            # Append the new token and extend the value mask accordingly
            new_token = torch.tensor([[next_token_id]], device=input_ids.device)
            input_ids = torch.cat([input_ids, new_token], dim=1)
            new_is_value = torch.tensor([[is_value]], dtype=torch.bool, device=value_mask.device)
            value_mask = torch.cat([value_mask, new_is_value], dim=1)
            if eos_token_id is not None and next_token_id == eos_token_id:
                break
        return input_ids, value_mask
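# %% [markdown]
# Illustrative smoke test of the full model with a tiny config and hand-written inputs
# (all sizes and token ids below are arbitrary): one forward pass with labels should return a
# scalar loss plus both heads with the expected shapes.
# %%
_cfg = YiziLMConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
_m = YiziLM(_cfg)
_ids = torch.tensor([[100, 101, 103, 30260, 30270, 106]])       # the next after 260 270 ?
_vm = torch.tensor([[False, False, False, True, True, False]])
_lab = torch.tensor([[101, 103, 30260, 30270, 30280, -100]])    # next-token labels, answer = 280
_out = _m(input_ids=_ids, value_mask=_vm, labels=_lab)
print(_out["loss"].item(), _out["word_logits"].shape, _out["value_pred"].shape)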
# %% [markdown]
# 6. Data generation
# %%
# Toy vocabulary (simplified)
word_to_id = {
"the": 100, "next": 101, "number": 102, "after": 103, "is": 104,
"predict": 105, "?": 106, "[EOS]": 107
}
UNK_ID = 108
VALUE_OFFSET = 30000
def encode_input(tokens: List) -> Tuple[torch.Tensor, torch.BoolTensor]:
"""
将混合 tokens 编码为 input_ids 和 value_mask
支持字符串或整数
"""
input_ids = []
value_mask = []
for t in tokens:
if isinstance(t, str):
tid = word_to_id.get(t.lower(), UNK_ID)
input_ids.append(tid)
value_mask.append(False)
elif isinstance(t, int):
raw_id = t + VALUE_OFFSET
input_ids.append(raw_id)
value_mask.append(True)
return (
torch.tensor([input_ids], dtype=torch.long),
torch.tensor([value_mask], dtype=torch.bool)
)
def generate_arithmetic_batch_v2(
batch_size: int = 32,
context_len: int = 4,
pred_len: int = 1,
num_range: tuple = (100, 999),
step_candidates: list = None
):
"""
生成带有自然语言提示的等差数列预测任务
输入:["the", "sequence", 260, 270, 280, "?"]
输出:label = [?, 260, 270, 280, 290, -100]
"""
if step_candidates is None:
step_candidates = [1, -1, 2, -2, 5, -5, 10, -10]
all_input_ids = []
all_labels = []
all_value_masks = []
    while len(all_input_ids) < batch_size:
        start = random.randint(*num_range)
        step = random.choice(step_candidates)
        seq = [start + i * step for i in range(context_len)]
        # Resample if any value falls outside the representable range [0, 9999]
        if any(x < 0 or x > 9999 for x in seq):
            continue
        # Build the input with a natural-language prompt
        prompt = ["the", "next", "after"] + seq[:-1] + ["?"]
        target = seq[-1]
        input_ids, value_mask = encode_input(prompt)
        label_ids = input_ids.clone()
        label_ids[0, -1] = target + VALUE_OFFSET  # replace the trailing "?" with the answer
        label_ids = torch.roll(label_ids, shifts=-1, dims=1)  # shift left by one to form next-token labels
        label_ids[0, -1] = -100  # ignore the final position
all_input_ids.append(input_ids[0])
all_labels.append(label_ids[0])
all_value_masks.append(value_mask[0])
return (
torch.stack(all_input_ids),
torch.stack(all_labels),
torch.stack(all_value_masks)
)
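# %% [markdown]
# Illustrative peek at the data pipeline: encode one hand-written prompt, then generate a tiny
# batch and print the first input/label pair to confirm the shift-by-one labelling. The prompt
# and batch size below are arbitrary.
# %%
_ids, _vmask = encode_input(["the", "next", "after", 260, 270, "?"])
print(_ids.tolist())    # [[100, 101, 103, 30260, 30270, 106]]
print(_vmask.tolist())  # [[False, False, False, True, True, False]]
_bids, _blabels, _bvmask = generate_arithmetic_batch_v2(batch_size=2)
print(_bids.shape, _blabels.shape, _bvmask.shape)  # each (2, 7)
print(_bids[0].tolist())
print(_blabels[0].tolist())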
# %% [markdown]
# 7. Model initialization
# %%
config = YiziLMConfig(
vocab_size=30000,
hidden_size=512,
num_hidden_layers=4,
num_attention_heads=8,
intermediate_size=2048,
max_position_embeddings=8192,
value_offset=VALUE_OFFSET
)
model = YiziLM(config)
total_params = sum(p.numel() for p in model.parameters())
print(f"🚀 Total Parameters: {total_params:,}")
# %% [markdown]
# 8. Training
# %%
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-5)
print("🔥 开始训练...")
for step in range(500):
input_ids, labels, value_mask = generate_arithmetic_batch_v2(batch_size=16)
outputs = model(input_ids=input_ids, value_mask=value_mask, labels=labels)
loss = outputs["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
scheduler.step()
if step % 50 == 0:
print(f"Step {step} | Loss: {loss.item():.4f} | Grad Norm: {grad_norm:.4f}")
# %% [markdown]
# 9. Testing and inference
# %%
@torch.no_grad()
def test_predict_next(current_seq: List[int]):
model.eval()
prompt = ["the", "next", "after"] + current_seq + ["?"]
input_ids, value_mask = encode_input(prompt)
    # The task expects a numeric answer, so query the value regression head directly
    generated, _ = model.generate(
        input_ids=input_ids,
        value_mask=value_mask,
        max_new_tokens=1,
        temperature=0.1,
        top_k=10,
        strategy="force_value"
    )
pred_token = generated[0, -1].item()
if pred_token >= VALUE_OFFSET:
predicted_value = pred_token - VALUE_OFFSET
        expected = current_seq[-1] + (current_seq[-1] - current_seq[-2])  # arithmetic-progression extrapolation
print(f"🎯 Input: {current_seq} → Predicted: {predicted_value}, Expected: {expected} {'✔️' if predicted_value == expected else '❌'}")
else:
word = [k for k, v in word_to_id.items() if v == pred_token]
print(f"🔤 Word Output: {word}")
print("\n🧪 测试结果:")
test_predict_next([260, 270, 280])
test_predict_next([500, 490, 480])
test_predict_next([101, 103, 105])
# %% [markdown]
# 10. Saving the model
# %%
save_dict = {
'model_state_dict': model.state_dict(),
'config': config.__dict__,
'word_to_id': word_to_id,
'value_offset': VALUE_OFFSET
}
torch.save(save_dict, "yizilm_hybrid_numerical.pth")
print("💾 模型已保存至 yizilm_hybrid_numerical.pth")
# %% [markdown]
# 11. Attention visualization
# %%
attn_weights = model.layers[0].attn.attn_weights
if attn_weights is not None and attn_weights.ndim == 4:
plt.figure(figsize=(6, 6))
avg_attn = attn_weights[0].mean(0).cpu().numpy() # average over heads
plt.imshow(avg_attn, cmap='viridis', aspect='auto')
plt.title("Average Attention Map (Layer 1)")
plt.xlabel("Key Position")
plt.ylabel("Query Position")
plt.colorbar()
plt.tight_layout()
plt.show()