# %%
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Dict
# %% [markdown]
# 1. RoPE: Rotary Position Embedding (implemented from scratch)
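#
# For reference: RoPE rotates each consecutive channel pair $(x_{2i}, x_{2i+1})$ of a per-head
# query/key vector at position $t$ by the angle $t\,\theta_i$, where $\theta_i = \mathrm{base}^{-2i/d}$
# and $d$ is the per-head dimension. Viewing each pair as a complex number, the rotation is
#
# $$\big(x_{2i} + \mathrm{i}\,x_{2i+1}\big)\,e^{\mathrm{i}\,t\,\theta_i},$$
#
# so the attention score between a rotated query at position $m$ and a rotated key at position $n$
# depends only on the offset $m - n$. The sanity-check cell after the two helpers below verifies
# both the table shape and this relative-position property.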
# %%
def precompute_freqs_cis(hidden_size: int, max_seq_len: int, base: int = 10000, num_attention_heads: int = 8) -> torch.Tensor:
    """
    Precompute the rotary position embedding (RoPE) table in complex form.
    Args:
        hidden_size: model dimension (e.g. 512)
        max_seq_len: maximum sequence length (e.g. 8192)
        base: frequency base, default 10000
        num_attention_heads: number of attention heads (head_dim = hidden_size // num_attention_heads)
    Returns:
        freqs_cis: complex tensor of shape (max_seq_len, head_dim // 2)
    """
    head_dim = hidden_size // num_attention_heads
    # theta_i = -(2i / head_dim) * ln(base), so exp(theta_i) = base^(-2i / head_dim)
    theta = torch.arange(0, head_dim, 2).float() * (-math.log(base) / head_dim)
    freqs = torch.exp(theta)  # standard RoPE inverse frequencies, shape: (head_dim // 2,)
    t = torch.arange(max_seq_len, device=freqs.device)  # positions
    freqs = torch.outer(t, freqs)  # (max_seq_len, head_dim // 2)
    # Convert to complex numbers: cos(theta) + i*sin(theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis  # shape: (max_seq_len, head_dim // 2)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE to the query and key tensors.
    Args:
        q: (bsz, n_heads, seq_len, head_dim)
        k: (bsz, n_heads, seq_len, head_dim)
        freqs_cis: (seq_len, head_dim // 2)
    Returns:
        q_embed, k_embed
    """
    bsz, n_heads, seq_len, head_dim = q.shape
    half_dim = head_dim // 2
    # Reshape to (bsz, n_heads, seq_len, head_dim // 2, 2): consecutive channel pairs act as complex numbers
    def reshape_for_complex(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.view(bsz, n_heads, seq_len, half_dim, 2).float()
    q_ = reshape_for_complex(q)
    k_ = reshape_for_complex(k)
    cos = freqs_cis.real
    sin = freqs_cis.imag
    # Add leading dimensions so the table broadcasts over batch and heads
    cos = cos.view(1, 1, seq_len, half_dim)  # (1, 1, seq_len, head_dim // 2)
    sin = sin.view(1, 1, seq_len, half_dim)
    # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
    # carried out here with real arithmetic on the (real, imag) channel pairs
    q_out_real = q_[..., 0] * cos - q_[..., 1] * sin
    q_out_imag = q_[..., 0] * sin + q_[..., 1] * cos
    q_out = torch.stack([q_out_real, q_out_imag], dim=-1).flatten(-2)
    k_out_real = k_[..., 0] * cos - k_[..., 1] * sin
    k_out_imag = k_[..., 0] * sin + k_[..., 1] * cos
    k_out = torch.stack([k_out_real, k_out_imag], dim=-1).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)
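# %% [markdown]
# A quick, illustrative sanity check for the two RoPE helpers. The sizes below are arbitrary test
# values (hidden_size=512 with 8 heads gives head_dim=64). The second part checks the relative-position
# property: queries/keys with identical content should produce the same score whenever the offset
# between their positions is the same.
# %%
_freqs = precompute_freqs_cis(hidden_size=512, max_seq_len=16, num_attention_heads=8)
print(_freqs.shape, _freqs.dtype)  # expected: torch.Size([16, 32]) torch.complex64

_q = torch.randn(1, 1, 16, 64)
_k = torch.randn(1, 1, 16, 64)
_q_rot, _k_rot = apply_rotary_pos_emb(_q, _k, _freqs)
print(_q_rot.shape)  # unchanged shape: (1, 1, 16, 64)

# Same content placed at every position, so only the rotation differs between positions
_q_fixed = _q[:, :, :1, :].expand(1, 1, 16, 64).contiguous()
_k_fixed = _k[:, :, :1, :].expand(1, 1, 16, 64).contiguous()
_qr, _kr = apply_rotary_pos_emb(_q_fixed, _k_fixed, _freqs)
# Offsets (2 - 5) and (6 - 9) are equal, so the two scores should match
s1 = (_qr[0, 0, 2] * _kr[0, 0, 5]).sum()
s2 = (_qr[0, 0, 6] * _kr[0, 0, 9]).sum()
print(torch.allclose(s1, s2, atol=1e-4))  # expected: True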
# %% [markdown]
# 2. Custom attention layer
# %%
class YiziAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        # Project to Q, K, V and split into heads
        query_states = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE to queries and keys
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freqs_cis)
        # Scaled dot-product attention with a causal mask (each token attends only to itself and earlier tokens)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        attn_weights = attn_weights.masked_fill(causal_mask, float("-inf"))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value_states)
        # Merge heads back into hidden_size
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.o_proj(attn_output)
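# %% [markdown]
# A minimal smoke test for the attention layer (illustrative only; the sizes match the small config
# used later): the output keeps the (batch, seq_len, hidden_size) shape of the input.
# %%
_attn = YiziAttention(hidden_size=512, num_heads=8)
_x = torch.randn(2, 10, 512)
_out = _attn(_x, precompute_freqs_cis(hidden_size=512, max_seq_len=10, num_attention_heads=8))
print(_out.shape)  # expected: torch.Size([2, 10, 512])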
# %% [markdown]
# 3. Feed-forward network + residual connections
# %%
class YiziBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        super().__init__()
        self.attn = YiziAttention(hidden_size, num_heads)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size)
        )
        self.norm2 = nn.LayerNorm(hidden_size)
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Attention with pre-norm and residual connection
        x = x + self.attn(self.norm1(x), freqs_cis)
        # MLP with pre-norm and residual connection
        x = x + self.mlp(self.norm2(x))
        return x
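# %% [markdown]
# The same kind of shape check for a full pre-norm block (again illustrative, with arbitrary small sizes).
# %%
_block = YiziBlock(hidden_size=512, num_heads=8, intermediate_size=2048)
_h = _block(torch.randn(2, 10, 512), precompute_freqs_cis(hidden_size=512, max_seq_len=10, num_attention_heads=8))
print(_h.shape)  # expected: torch.Size([2, 10, 512])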
# %% [markdown]
# 4. The main model: YiziLM
# %%
class YiziLMConfig:
    def __init__(
        self,
        vocab_size: int = 30000,
        hidden_size: int = 512,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 8,
        max_position_embeddings: int = 8192,
        intermediate_size: int = 2048
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.intermediate_size = intermediate_size
class YiziLM(nn.Module):
    def __init__(self, config: YiziLMConfig):
        super().__init__()
        self.config = config
        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Precompute RoPE for all positions and register it as a buffer (saved with the model)
        freqs_cis = precompute_freqs_cis(
            hidden_size=config.hidden_size,
            max_seq_len=config.max_position_embeddings,
            base=10000,
            num_attention_heads=config.num_attention_heads
        )
        self.register_buffer("freqs_cis", freqs_cis)
        # Stack of Transformer blocks
        self.layers = nn.ModuleList([
            YiziBlock(
                hidden_size=config.hidden_size,
                num_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size
            ) for _ in range(config.num_hidden_layers)
        ])
        # Final norm and decoding head
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying (optional): the token embedding and lm_head share parameters
        self.lm_head.weight = self.embed_tokens.weight
    def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        inputs_embeds = self.embed_tokens(input_ids)
        bsz, seq_len, _ = inputs_embeds.shape
        # Slice the precomputed freqs_cis for the current sequence length
        freqs_cis = self.freqs_cis[:seq_len]
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, freqs_cis)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        output = {"logits": logits}
        if labels is not None:
            # Standard causal-LM loss: predict token t+1 from position t, so shift logits and labels by one
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            output["loss"] = loss
        return output
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 50,
        eos_token_id: Optional[int] = None
    ) -> torch.LongTensor:
        """
        Autoregressive generation with temperature and top-k sampling.
        """
        for _ in range(max_new_tokens):
            logits = self.forward(input_ids)["logits"]
            next_token_logits = logits[:, -1, :] / temperature
            # Top-k filtering: mask out everything below the k-th largest logit
            if top_k > 0:
                v, _ = torch.topk(next_token_logits, top_k)
                pivot = v[:, [-1]]
                next_token_logits = torch.where(next_token_logits < pivot, torch.full_like(next_token_logits, -float('inf')), next_token_logits)
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Note: this early-stop check assumes batch size 1
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return input_ids
# %%
# Create the configuration
config = YiziLMConfig(
    vocab_size=30000,
    hidden_size=512,
    num_hidden_layers=6,
    num_attention_heads=8,
    max_position_embeddings=8192
)
# Build the model
model = YiziLM(config)
# Input data
input_ids = torch.randint(0, 30000, (2, 10))  # batch=2, seq_len=10
labels = input_ids.clone()
# Training-style forward pass (with labels)
output = model(input_ids=input_ids, labels=labels)
print("Loss:", output["loss"].item())
print("Logits shape:", output["logits"].shape)
# Inference: autoregressive generation
prompt = torch.tensor([[100, 200, 300]])  # initial prompt tokens
generated = model.generate(prompt, max_new_tokens=20, eos_token_id=2)
print("Generated:", generated[0].tolist())
# %%
torch.save(model.state_dict(), "yizilm_70m.pth")