GPT-2 剪枝模型推理函数封装实战:输入输出结构与结果解析

部署运行你感兴趣的模型镜像

本文基于 prune_infer.py,讲解如何对剪枝后的 GPT-2 小模型进行高效推理调用,封装成 infer(prompt) 函数,支持多句输入、多轮验证,并详细解析 logits 输出结构与 token 解码方法。


🔧 环境配置与依赖说明

你需要准备以下内容:

  • 剪枝后的 GPT-2 student 模型(见第1篇)
  • 原始 tokenizer(无需剪枝)
  • 已安装 transformerstorch

📁 文件结构说明

.
├── prune_infer.py
├── gpt2_student_v2_pruned/       # 剪枝后模型
├── ../python3_distillation/gpt2_student_v2/  # tokenizer

1️⃣ 加载模型与分词器

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

model_path = "./gpt2_student_v2_pruned"
tokenizer_path = "../python3_distillation/gpt2_student_v2"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPT2LMHeadModel.from_pretrained(model_path).to(device).eval()
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

说明:

  • model.eval():关闭 Dropout,进入推理模式
  • tokenizer 加载原始路径即可,剪枝不影响词表

2️⃣ 封装推理函数 infer()

def infer(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits  # shape: [B, L, V]
    last_token_id = int(logits[0, -1].argmax())
    last_token = tokenizer.decode([last_token_id], clean_up_tokenization_spaces=True).strip()

    return last_token_id, last_token

解释:

名称说明
logits每个 token 对 vocab 中所有词的打分
logits[:, -1, :]最后一个 token 的预测分布
argmax()找到得分最高的 token id
decode()将 token id 转换为字符串

3️⃣ 多 prompt 批量测试

prompts = [
    "Hello world",
    "The sky is",
    "I love",
    "Artificial intelligence is"
]

for i, prompt in enumerate(prompts, 1):
    token_id, token_text = infer(prompt)
    print(f"\n📝 Prompt {i}: {prompt}")
    print(f"🔹 Predicted Token ID: {token_id}")
    print(f"🔹 Predicted Token Text: {token_text}")

输出示例:

📝 Prompt: Hello world
🔹 Predicted Token: !

📝 Prompt: Artificial intelligence is
🔹 Predicted Token: transforming

4️⃣ 输出结构可视化理解

假设:

logits.shape = [1, 5, 50257]

说明:

  • Batch = 1(单条输入)
  • Sequence length = 5(token 数量)
  • Vocab size = 50257
logits[0, -1, :]  →  第 5 个词的 50257 个词概率分布

✅ 可拓展方向

  • 增加 Top-k / Top-p 采样
  • 生成多个词而非单个 token
  • 支持 batch 输入与响应格式封装
  • 接入 Flask 或 API 接口(见第1篇)

📌 总结

  • 剪枝模型的调用方式与原模型一致,兼容 Hugging Face 接口
  • 推理结构为 prompt → logits → token_id → decode
  • 可轻松封装成模块,支持网页调用、接口服务等二次开发

🧭 本系列 GPT-2 模型剪枝部署项目系列四部曲


📌 YoanAILab 技术导航页

💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页

📚 包含内容:

  • 🧠 GPT-2 项目源码(GitHub)
  • ✍️ 优快云 技术专栏合集
  • 💼 知乎转型日志
  • 📖 公众号 YoanAILab 全文合集

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YoanAILab

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值