本文基于
prune_infer.py,讲解如何对剪枝后的 GPT-2 小模型进行高效推理调用,封装成infer(prompt)函数,支持多句输入、多轮验证,并详细解析 logits 输出结构与 token 解码方法。
🔧 环境配置与依赖说明
你需要准备以下内容:
- 剪枝后的 GPT-2 student 模型(见第1篇)
- 原始 tokenizer(无需剪枝)
- 已安装
transformers和torch
📁 文件结构说明
.
├── prune_infer.py
├── gpt2_student_v2_pruned/ # 剪枝后模型
├── ../python3_distillation/gpt2_student_v2/ # tokenizer
1️⃣ 加载模型与分词器
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
model_path = "./gpt2_student_v2_pruned"
tokenizer_path = "../python3_distillation/gpt2_student_v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained(model_path).to(device).eval()
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
说明:
model.eval():关闭 Dropout,进入推理模式tokenizer加载原始路径即可,剪枝不影响词表
2️⃣ 封装推理函数 infer()
def infer(prompt):
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits # shape: [B, L, V]
last_token_id = int(logits[0, -1].argmax())
last_token = tokenizer.decode([last_token_id], clean_up_tokenization_spaces=True).strip()
return last_token_id, last_token
解释:
| 名称 | 说明 |
|---|---|
logits | 每个 token 对 vocab 中所有词的打分 |
logits[:, -1, :] | 最后一个 token 的预测分布 |
argmax() | 找到得分最高的 token id |
decode() | 将 token id 转换为字符串 |
3️⃣ 多 prompt 批量测试
prompts = [
"Hello world",
"The sky is",
"I love",
"Artificial intelligence is"
]
for i, prompt in enumerate(prompts, 1):
token_id, token_text = infer(prompt)
print(f"\n📝 Prompt {i}: {prompt}")
print(f"🔹 Predicted Token ID: {token_id}")
print(f"🔹 Predicted Token Text: {token_text}")
输出示例:
📝 Prompt: Hello world
🔹 Predicted Token: !
📝 Prompt: Artificial intelligence is
🔹 Predicted Token: transforming
4️⃣ 输出结构可视化理解
假设:
logits.shape = [1, 5, 50257]
说明:
- Batch = 1(单条输入)
- Sequence length = 5(token 数量)
- Vocab size = 50257
logits[0, -1, :] → 第 5 个词的 50257 个词概率分布
✅ 可拓展方向
- 增加 Top-k / Top-p 采样
- 生成多个词而非单个 token
- 支持 batch 输入与响应格式封装
- 接入 Flask 或 API 接口(见第1篇)
📌 总结
- 剪枝模型的调用方式与原模型一致,兼容 Hugging Face 接口
- 推理结构为
prompt → logits → token_id → decode - 可轻松封装成模块,支持网页调用、接口服务等二次开发
🧭 本系列 GPT-2 模型剪枝部署项目系列四部曲
- 🚀 第1篇:GPT-2 小模型剪枝实战:L1 Unstructured 剪枝策略与实现详解
- 🌐 第2篇:GPT-2 剪枝模型推理函数封装实战:输入输出结构与结果解析
- 🧠 第3篇:GPT-2 剪枝前后性能对比实测:加速效果与输出一致性全分析
- 💼 第4篇:GPT-2 Student 模型剪枝部署实战:Flask 接口封装与服务调用指南
📌 YoanAILab 技术导航页
💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页
📚 包含内容:
- 🧠 GPT-2 项目源码(GitHub)
- ✍️ 优快云 技术专栏合集
- 💼 知乎转型日志
- 📖 公众号 YoanAILab 全文合集

被折叠的 条评论
为什么被折叠?



