🎯 聊天机器人(Mini GPT)
🤖 构建一个简化版 ChatGPT,支持自然语言多轮对话、生成回复、上下文记忆。
本项目基于 GPT-style Decoder-only Transformer,采用自回归生成策略,并引入 Top-k
/ Top-p
控制生成质量。
✅ 项目亮点
- 使用 GPT 单向解码器结构(支持文本生成)
- 支持
Top-k
、Top-p
(Nucleus)采样生成 - 支持多轮上下文拼接(可选记忆窗口)
- 支持停止词、最大长度等生成安全控制
- 可部署为对话接口 / 微信机器人 / Gradio 小工具
📚 数据集推荐
✅ 数据格式示例(JSON)
{
"history": ["你好啊", "你好!今天过得怎么样?"],
"reply": "挺不错的,谢谢你~"
}
🧠 模型结构设计(Mini GPT)
class MiniGPT(tf.keras.Model):
def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=4, max_len=128):
super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.pos_encoding = positional_encoding(max_len, d_model)
self.decoder_layers = [DecoderLayer(d_model, num_heads, d_model * 4) for _ in range(num_layers)]
self.final = tf.keras.layers.Dense(vocab_size)
def call(self, x, look_ahead_mask):
x = self.embedding(x) + self.pos_encoding[:, :tf.shape(x)[1], :]
for layer in self.decoder_layers:
x = layer(x, None, look_ahead_mask, None)
return self.final(x)
- 注意:无 Encoder,仅 Decoder
- Mask 采用
look-ahead mask
以防止看到未来 token - 输入为拼接好的上下文对话序列(例如 “用户:xx 模型:xx 用户:…”)
🛠️ 训练流程(自回归)
- 输入:
[token_1, token_2, ..., token_n-1]
- 标签:
[token_2, ..., token_n]
- 损失:
SparseCategoricalCrossentropy(from_logits=True)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
🧪 推理流程(Greedy / Top-k / Top-p)
✅ Greedy Search
def generate_response(input_ids, model, max_len=40):
for _ in range(max_len):
logits = model(tf.constant([input_ids]), look_ahead_mask=None)
next_token = tf.argmax(logits[:, -1, :], axis=-1).numpy()[0]
input_ids.append(next_token)
if next_token == tokenizer.word_index.get('<end>'):
break
return tokenizer.decode(input_ids)
✅ Top-k / Top-p 采样
import numpy as np
def sample_from_logits(logits, top_k=50, top_p=0.9):
logits = logits[-1]
logits = tf.nn.softmax(logits).numpy()
if top_k > 0:
top_k_indices = np.argpartition(logits, -top_k)[-top_k:]
top_k_logits = logits[top_k_indices]
top_k_logits /= top_k_logits.sum()
return np.random.choice(top_k_indices, p=top_k_logits)
sorted_indices = np.argsort(logits)[::-1]
cumulative_probs = np.cumsum(logits[sorted_indices])
cutoff = sorted_indices[cumulative_probs <= top_p]
cutoff_probs = logits[cutoff] / logits[cutoff].sum()
return np.random.choice(cutoff, p=cutoff_probs)
💬 多轮对话状态管理
def build_input_sequence(history, tokenizer, max_len=128):
tokens = []
for utter in history:
tokens += tokenizer.encode(utter) + [tokenizer.sep_token_id]
return tokens[-max_len:]
✅ 示例:
用户:你是谁?
🤖:我是一个聊天机器人呀!
用户:你会讲笑话吗?
→ 输入模型的 token: ["你是谁", "<sep>", "我是一个聊天机器人呀!", "<sep>", "你会讲笑话吗?"]
🧼 安全控制策略
stop_token
:遇到 <end>
或特殊停用词即停止生成max_length
:设置最大 token 长度(如 40)repetition_penalty
:避免重复 token- 可添加关键词过滤 / 敏感词屏蔽等模块
🌐 部署建议
平台 | 推荐方式 | 说明 |
---|
Web 聊天 | Flask / FastAPI | WebSocket 或 REST |
微信机器人 | ItChat / 青云客 接入 | 私聊群聊皆可 |
TF.js | 极简小模型可行 | 不建议复杂多轮 |
Android | 不推荐 | 模型太大 |
📁 项目结构建议
mini-gpt-chatbot/
├── data/
│ └── persona_chat.jsonl
├── model.py
├── train.py
├── predict.py
├── app.py (接口服务)
├── tokenizer/
├── export/
│ └── saved_model/
└── requirements.txt
📋 示例对话(中文)
👤:你喜欢看电影吗?
🤖:当然啦,我最喜欢的是科幻片!
👤:你喜欢星际穿越吗?
🤖:超喜欢,那是我心中最经典的科幻片之一。
✅ 项目小结
内容 | 技术实现 |
---|
模型结构 | GPT-style Decoder |
生成机制 | 自回归 + Top-k / Top-p |
数据格式 | 多轮对话 JSONL |
训练机制 | Teacher Forcing + Mask Loss |
部署 | Flask / 轻量前端 UI |
控制机制 | 最大长度 / 停止词 / 重复惩罚 |