本文对比两种 GPT-2 蒸馏小模型的推理实现(v1 和 v2),展示如何从本地加载模型、进行 prompt 推理、提取 Top-1 token 输出,并通过 Python 封装 infer() 函数以便后续 API 封装与前端调用。
📎 两个版本简介
| 项目文件 | 模型来源 | 特点 |
|---|---|---|
| distill_infer.py | student from v1 | 原始 GPT-2 结构,参数较多 |
| distill_infer_v2.py | student from v2 | 使用 GPT2Config 构建的 6 层小模型 |
两者调用结构类似,仅模型大小与来源不同
1️⃣ 模型加载流程(通用)
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
MODEL_DIR = "./gpt2_student_v2" # v1: gpt2_student
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("🚀 加载小 Student 模型中...")
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR).to(device).eval()
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
print("✅ 模型加载完成")
2️⃣ 推理函数封装:支持单条或多条 prompt
def infer(prompts, max_length=30):
if isinstance(prompts, str):
prompts = [prompts] # 单条输入转列表
results = []
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits # shape: [batch, seq_len, vocab_size]
last_logits = logits[:, -1, :] # 取每句话最后一个 token 的预测
top1_token_id = torch.argmax(last_logits, dim=-1).item()
top1_token = tokenizer.decode([top1_token_id], clean_up_tokenization_spaces=True).strip()
results.append({
"prompt": prompt,
"top1_token_id": top1_token_id,
"top1_token": top1_token
})
return results
3️⃣ 推理输出结构解析
假设:
logits.shape = [2, 5, 50257]
说明:
- batch 中有 2 句话
- 每句话长度为 5 个 token
- 每个 token 的预测是对 50257 个词的打分
logits[:, -1, :] → shape: [2, 50257]
每行即为一句话最后 token 的预测分布。
4️⃣ 实战输出示例
test_prompts = [
"Hello world",
"The sky is",
"I love",
"Artificial intelligence is"
]
outputs = infer(test_prompts)
for i, output in enumerate(outputs):
print(f"📝 Prompt {i+1}: {output['prompt']}")
print(f"🔹 Token ID: {output['top1_token_id']}")
print(f"🔹 Token Text: {output['top1_token']}")
输出示例:
📝 Prompt: Hello world
🔹 Token: !
📝 Prompt: I love
🔹 Token: you
5️⃣ 对比分析:v1 vs v2 推理体验
| 维度 | v1(gpt2_student) | v2(gpt2_student_v2) |
|---|---|---|
| 模型结构 | 与 GPT-2 一致 | Transformer 缩减至 6 层 |
| 推理速度 | 较慢 | 明显更快 |
| 输出一致性 | 高 | 基本一致 |
| 文件大小 | 大(>500MB) | 小(≈250MB) |
结论:v2 更适合部署到轻量级服务或前端场景,而 v1 可用于教学与结构可视化用途。
📌 总结
- 推理核心逻辑是提取 logits 的最后一个位置,并取 Top-1 token 输出
- 封装为
infer(prompt)函数便于复用与服务化 v2版本 Student 小模型兼具速度与效果,推荐上线使用
🧭 本文为 GPT-2 蒸馏压缩项目第一篇,共3篇
- 🧩 第一篇:从蒸馏到压缩:两种方式训练GPT-2小模型 Student
- 🚀 第二篇:GPT-2 蒸馏模型推理实战:标准 Student vs 压缩 Student 的调用对比
- 🌐 第三篇:GPT-2 蒸馏小模型部署实战:Flask 封装推理接口与网页调用演示
📌 YoanAILab 技术导航页
💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页
📚 包含内容:
- 🧠 GPT-2 项目源码(GitHub)
- ✍️ 优快云 技术专栏合集
- 💼 知乎转型日志
- 📖 公众号 YoanAILab 全文合集
2017

被折叠的 条评论
为什么被折叠?



