# GPT-2 API Service Documentation

### Basic Information

- Base URL: http://localhost:8000
- Interactive docs: http://localhost:8000/docs (Swagger UI)
- Alternative docs: http://localhost:8000/redoc (ReDoc)
### Authentication

The current version of the API requires no authentication. For production deployments, consider adding one of the following (a minimal sketch of the first option follows the list):

- API key authentication
- OAuth 2.0
- JWT tokens
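As a rough illustration of API-key authentication, here is a minimal sketch implemented as a FastAPI dependency. The header name `X-API-Key` and the `require_api_key` helper are illustrative choices, not part of the service above:

```python
# Hypothetical sketch: API-key authentication as a FastAPI dependency.
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import APIKeyHeader

API_KEY = "change-me"  # in practice, load this from an environment variable
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

async def require_api_key(api_key: str = Depends(api_key_header)) -> str:
    # Reject requests whose X-API-Key header is missing or wrong
    if api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return api_key

app = FastAPI()

# Protect the generation endpoint with the dependency
@app.post("/generate", dependencies=[Depends(require_api_key)])
async def generate_text():
    ...
```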
### Available Endpoints

#### Text Generation

Request: `POST /generate`

Request body:

```json
{
  "prompt": "your prompt",
  "max_new_tokens": 512,
  "repetition_penalty": 1.1,
  "temperature": 0.7,
  "top_p": 0.9
}
```

Response:

```json
{
  "generated_text": "the generated text",
  "request_id": "unique request ID",
  "processing_time": 0.5,
  "model_info": {
    "name": "GPT-2",
    "type": "causal-language-model"
  },
  "input_tokens": 20,
  "output_tokens": 150
}
```
#### Health Check

Request: `GET /health`

Response:

```json
{
  "status": "healthy",
  "model_loaded": true,
  "uptime": 3600.0,
  "request_count": 100,
  "error_count": 5,
  "model_loading_time": 15.2,
  "device": "cuda"
}
```
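For a quick liveness probe from Python (a minimal sketch, assuming the service is running locally on port 8000):

```python
# Poll the health endpoint and report basic status
import requests

resp = requests.get("http://localhost:8000/health", timeout=5)
resp.raise_for_status()  # non-2xx means the service is unhealthy
info = resp.json()
print("status:", info["status"], "| model loaded:", info["model_loaded"])
```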
### Usage Examples

#### Python

```python
import requests

url = "http://localhost:8000/generate"
data = {
    "prompt": "Write a short essay on the history of artificial intelligence",
    "max_new_tokens": 300,
    "temperature": 0.8,
}

# json= serializes the payload and sets the Content-Type header for us
response = requests.post(url, json=data)
result = response.json()
print("Generated text:", result["generated_text"])
print("Processing time:", result["processing_time"], "s")
```
#### curl

```bash
curl -X POST "http://localhost:8000/generate" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Write a poem about spring", "max_new_tokens": 100, "temperature": 0.7}'
```
### Error Codes

| Status code | Meaning | Likely cause |
|---|---|---|
| 200 | Success | Normal response |
| 400 | Bad request | Malformed input or out-of-range values |
| 422 | Validation error | Request body failed validation |
| 500 | Internal server error | Model inference failure or system error |
| 503 | Service unavailable | Model failed to load or the service is restarting |
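A hedged sketch of how a client might branch on these codes (the `detail` field is FastAPI's default error payload; adjust to your actual response schema):

```python
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Hello", "max_new_tokens": 10},
    timeout=30,
)
if resp.status_code == 200:
    print(resp.json()["generated_text"])
elif resp.status_code in (400, 422):
    # FastAPI validation errors carry a "detail" field by default
    print("Invalid request:", resp.json().get("detail"))
elif resp.status_code == 503:
    print("Service unavailable; retry once the model has loaded")
else:
    resp.raise_for_status()
```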
### Performance Tuning Tips

- Batching: for large volumes of short generation tasks, batch requests together
- Parameter tuning:
  - Long-form generation: lower temperature, raise repetition_penalty
  - Creative generation: raise temperature (0.8-1.2) and enable top_p
- Caching: cache results for repeated requests (a sketch follows this list)
- Resource allocation: on GPU, a larger batch_size can improve throughput
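As a rough illustration of the caching idea (the helper names here are made up, and caching only pays off when sampling is deterministic enough, e.g. at low temperature):

```python
# Minimal in-memory result cache keyed on the full parameter set
import hashlib
import json

_cache: dict = {}

def cache_key(payload: dict) -> str:
    # Identical parameter sets hash to the same key
    return hashlib.sha256(
        json.dumps(payload, sort_keys=True).encode()
    ).hexdigest()

def generate_with_cache(payload: dict, generate_fn) -> str:
    key = cache_key(payload)
    if key not in _cache:
        _cache[key] = generate_fn(payload)  # call the model only on a miss
    return _cache[key]
```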
### Troubleshooting

- Service fails to start:
  - Check whether the port is already in use
  - Verify the model files are complete
  - Confirm dependencies are installed correctly
- Slow inference:
  - Check whether GPU acceleration is actually in use (see the snippet below)
  - Try reducing max_new_tokens
  - Close other processes competing for resources
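A quick way to confirm from Python that a GPU is visible and that the model's weights actually live on it (a minimal check; `model` stands for your already-loaded Hugging Face model):

```python
import torch

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))

# For an already-loaded Hugging Face model, inspect where its weights live:
# print(next(model.parameters()).device)
```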
## Step 4: Implement the API Test Script
Create an `examples/test_api.py` file with automated tests:
```python
import time
import unittest

import requests


class TestGPT2API(unittest.TestCase):
    BASE_URL = "http://localhost:8000"

    def test_health_check(self):
        """Test the health check endpoint."""
        response = requests.get(f"{self.BASE_URL}/health")
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertEqual(data["status"], "healthy")
        self.assertTrue(data["model_loaded"])
        self.assertGreater(data["uptime"], 0)
        print("Health check passed. Server status:")
        print(f"  Model loaded: {data['model_loaded']}")
        print(f"  Uptime: {data['uptime']:.2f}s")
        print(f"  Request count: {data['request_count']}")
        print(f"  Error count: {data['error_count']}")
        print(f"  Device: {data['device']}")

    def test_text_generation_basic(self):
        """Test basic text generation."""
        payload = {
            "prompt": "Write a short paragraph about artificial intelligence.",
            "max_new_tokens": 100,
            "repetition_penalty": 1.1
        }
        response = requests.post(f"{self.BASE_URL}/generate", json=payload)
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertIn("generated_text", data)
        self.assertIn("request_id", data)
        self.assertGreater(len(data["generated_text"]), 10)
        self.assertIsInstance(data["processing_time"], float)
        print("\nBasic text generation test passed:")
        print(f"  Request ID: {data['request_id']}")
        print(f"  Processing time: {data['processing_time']:.2f}s")
        print(f"  Input tokens: {data['input_tokens']}")
        print(f"  Output tokens: {data['output_tokens']}")
        print(f"  Generated text: {data['generated_text'][:100]}...")

    def test_text_generation_with_parameters(self):
        """Test text generation with different sampling parameters."""
        test_cases = [
            {
                "name": "creative_mode",
                "parameters": {
                    "prompt": "Write a creative story about space exploration.",
                    "max_new_tokens": 150,
                    "temperature": 1.2,
                    "top_p": 0.9,
                    "repetition_penalty": 1.0
                }
            },
            {
                "name": "precise_mode",
                "parameters": {
                    "prompt": "Explain the concept of machine learning in simple terms.",
                    "max_new_tokens": 120,
                    "temperature": 0.5,
                    "top_p": 0.7,
                    "repetition_penalty": 1.2
                }
            }
        ]
        for case in test_cases:
            print(f"\nTesting {case['name']}...")
            response = requests.post(
                f"{self.BASE_URL}/generate",
                json=case["parameters"]
            )
            self.assertEqual(response.status_code, 200)
            data = response.json()
            print(f"  Generated text length: {len(data['generated_text'])} characters")
            print(f"  Processing time: {data['processing_time']:.2f}s")
            print(f"  First 100 characters: {data['generated_text'][:100]}...")

    def test_input_validation(self):
        """Test request validation."""
        invalid_cases = [
            {
                "name": "empty_prompt",
                "parameters": {"prompt": "", "max_new_tokens": 100},
                "expected_status": 422
            },
            {
                "name": "too_long_prompt",
                "parameters": {"prompt": "a" * 3000, "max_new_tokens": 100},
                "expected_status": 422
            },
            {
                "name": "invalid_temperature",
                "parameters": {"prompt": "Test", "max_new_tokens": 100, "temperature": 3.0},
                "expected_status": 422
            },
            {
                "name": "inappropriate_content",
                "parameters": {"prompt": "Generate harmful content about", "max_new_tokens": 100},
                "expected_status": 422
            }
        ]
        print("\nTesting input validation...")
        for case in invalid_cases:
            response = requests.post(
                f"{self.BASE_URL}/generate",
                json=case["parameters"]
            )
            self.assertEqual(
                response.status_code,
                case["expected_status"],
                f"Case {case['name']} failed: expected {case['expected_status']}, "
                f"got {response.status_code}"
            )
        print("All input validation tests passed")

    def test_performance_measurement(self):
        """Measure end-to-end latency and throughput."""
        payload = {
            "prompt": "Write a paragraph about the future of technology.",
            "max_new_tokens": 200,
            "temperature": 0.7
        }
        num_runs = 5
        total_time = 0.0
        tokens_per_second_list = []
        print(f"\nTesting performance with {num_runs} runs...")
        for i in range(num_runs):
            start_time = time.time()
            response = requests.post(f"{self.BASE_URL}/generate", json=payload)
            end_time = time.time()
            self.assertEqual(response.status_code, 200)
            data = response.json()
            run_time = end_time - start_time
            total_time += run_time
            tokens_per_second = data["output_tokens"] / run_time
            tokens_per_second_list.append(tokens_per_second)
            print(f"  Run {i + 1}: {run_time:.2f}s, {tokens_per_second:.2f} tokens/s")
        avg_time = total_time / num_runs
        avg_tokens_per_second = sum(tokens_per_second_list) / num_runs
        print("\nPerformance summary:")
        print(f"  Average time per request: {avg_time:.2f}s")
        print(f"  Average tokens per second: {avg_tokens_per_second:.2f}")


if __name__ == "__main__":
    # Give the service time to start (useful when run automatically in CI/CD)
    print("Waiting for API service to start...")
    time.sleep(10)
    unittest.main(verbosity=2)
```
## Step 5: Launch the Service and Verify with Tests

### Start the API Service

```bash
cd examples
python api_server.py
```

After a successful start you should see output like:
```text
2025-09-17 00:11:35 - gpt2-api - INFO - Starting model loading...
2025-09-17 00:11:38 - gpt2-api - INFO - Model loaded successfully on cuda device
2025-09-17 00:11:39 - gpt2-api - INFO - Model loading completed in 4.23 seconds
INFO:     Started server process [12345]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
### Browse the Auto-Generated API Docs

Open http://localhost:8000/docs in a browser to get the interactive Swagger UI, where you can exercise the API directly.
### Run the Automated Tests

In a new terminal:

```bash
cd examples
python test_api.py
```

The test script exercises each endpoint and prints performance metrics.
## High-Concurrency Deployment: From Single Instance to Production

### Nginx as a Reverse Proxy

Create an nginx.conf configuration file:
```nginx
# Note: limit_req_zone must be defined in the http context, not inside a server block.
limit_req_zone $binary_remote_addr zone=gpt2_api:10m rate=10r/s;

server {
    listen 80;
    server_name gpt2-api.example.com;

    # Enable gzip compression for JSON responses
    gzip on;
    gzip_types application/json text/plain;

    location / {
        proxy_pass http://localhost:8000;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
    }

    # Rate-limit the generation endpoint
    location /generate {
        proxy_pass http://localhost:8000;
        limit_req zone=gpt2_api burst=20 nodelay;
    }
}
```
### Managing the Service with systemd

Create /etc/systemd/system/gpt2-api.service:

```ini
[Unit]
Description=GPT-2 API Service
After=network.target

[Service]
User=appuser
Group=appuser
WorkingDirectory=/data/web/disk1/git_repo/openMind/gpt2/examples
ExecStart=/data/web/disk1/git_repo/openMind/gpt2/.venv/bin/python api_server.py
Restart=on-failure
RestartSec=5s
Environment="PATH=/data/web/disk1/git_repo/openMind/gpt2/.venv/bin"
Environment="PYTHONUNBUFFERED=1"

[Install]
WantedBy=multi-user.target
```
Reload systemd, then enable and start the service:

```bash
sudo systemctl daemon-reload
sudo systemctl enable gpt2-api
sudo systemctl start gpt2-api
```
### Multi-Instance Deployment

Create a start_multiple_instances.sh script (as written, it assumes api_server.py accepts a port number as its first argument):

```bash
#!/bin/bash
# Start several instances on consecutive ports
PORT=8000
NUM_INSTANCES=4

mkdir -p logs
for ((i=0; i<NUM_INSTANCES; i++)); do
    INSTANCE_PORT=$((PORT + i))
    echo "Starting instance on port $INSTANCE_PORT"
    nohup python api_server.py $INSTANCE_PORT > "logs/instance_$INSTANCE_PORT.log" 2>&1 &
done
echo "All $NUM_INSTANCES instances started"
```
## Performance Optimization and Resource Management

### GPU Memory Optimization Strategies

| Technique | Memory saved | Performance impact | Implementation effort |
|---|---|---|---|
| Half precision (FP16) | ~50% | Slight speedup | Easy |
| Int8 quantization | ~75% | Slight degradation | Moderate |
| Model parallelism | Split across devices | Slight degradation | Hard |
| Gradient checkpointing | 30-40% | 10-20% slower | Moderate |
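As a sketch of the simplest row, FP16 loading with Transformers looks roughly like this (the local path `.` mirrors the quantization example below; `device_map="auto"` requires the accelerate package):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the weights in half precision; roughly halves GPU memory for the weights
model = AutoModelForCausalLM.from_pretrained(
    ".",
    torch_dtype=torch.float16,
    device_map="auto",
)
```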
### Int8 Quantization (Modify the Loading Code)

```python
# Example: load the model with 8-bit quantization (requires the bitsandbytes package)
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    ".",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,        # enable 8-bit weights
        llm_int8_threshold=6.0    # outlier threshold for mixed-precision matmuls
    )
)
```
### Request Queue Management

Add a request queue so the service sheds load instead of being overwhelmed. Note that FastAPI does not export a Queue type; the sketch below uses asyncio.Queue as a bounded counter of in-flight requests:

```python
import asyncio

from fastapi import HTTPException

# Bounded queue used as a concurrency limiter (at most 100 in-flight requests)
request_queue: asyncio.Queue = asyncio.Queue(maxsize=100)

@app.post("/generate")
async def generate_text(request: GenerationRequest):
    if request_queue.full():
        raise HTTPException(status_code=429, detail="Too many requests. Please try again later.")
    # Reserve a slot for the lifetime of this request
    await request_queue.put(request)
    try:
        ...  # run model inference and build the response here
    finally:
        request_queue.get_nowait()
        request_queue.task_done()
```

An asyncio.Semaphore would achieve the same back-pressure with less machinery; a queue additionally lets you inspect how many requests are currently in flight.