基于Google Cloud AI Platform的机器学习模型训练与部署实战
本文将通过一个出租车费用预测模型的案例,详细介绍如何将本地开发的TensorFlow模型迁移到Google Cloud AI Platform上进行云端训练和部署。
为什么需要云端训练与部署?
在机器学习项目开发过程中,我们通常会先在本地进行模型开发和测试。但随着数据量增大和模型复杂度提高,本地环境往往会遇到以下问题:
- 计算资源不足,训练时间过长
- 难以进行大规模分布式训练
- 模型部署困难,缺乏生产级API服务
Google Cloud AI Platform提供了完整的解决方案,让我们可以:
- 利用云端强大的计算资源加速训练
- 轻松实现分布式训练
- 一键部署生产级REST API服务
- 无需管理基础设施
项目准备
1. 数据上传至Google Cloud Storage
云端服务无法访问本地文件,因此我们需要先将数据上传到Google Cloud Storage(GCS):
PROJECT = "your-project-name" # 替换为你的项目ID
BUCKET = "your-bucket-name" # 替换为你的存储桶名称
REGION = "us-central1" # 选择AI Platform可用的区域
TFVERSION = "1.14" # 使用的TensorFlow版本
使用gsutil命令行工具上传数据:
gcloud config set project ${PROJECT}
gsutil mb -l ${REGION} gs://${BUCKET}
gsutil -m cp *.csv gs://${BUCKET}/taxifare/smallinput/
2. 将代码打包为Python包
AI Platform训练作业需要将代码打包成Python包进行分发。Python包的基本结构是:
taxifaremodel/
__init__.py # 标识目录为Python包
model.py # 模型代码
task.py # 任务入口
创建包结构:
mkdir taxifaremodel
touch taxifaremodel/__init__.py
3. 编写模型代码
在model.py
中,我们需要实现以下核心功能:
- 数据输入管道
- 模型架构定义
- 训练评估逻辑
- 模型导出功能
关键点包括:
- 支持从GCS读取数据
- 将检查点文件写入GCS
- 实现ServingInputReceiver用于模型部署
4. 编写任务入口
task.py
是训练任务的入口文件,主要功能是:
- 解析命令行参数
- 调用模型训练函数
- 传递参数给训练流程
import argparse
from . import model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train_data_path", required=True)
parser.add_argument("--train_steps", type=int, default=1000)
parser.add_argument("--eval_data_path", required=True)
parser.add_argument("--output_dir", required=True)
parser.add_argument("--job-dir") # AI Platform要求的参数
args = parser.parse_args().__dict__
model.train_and_evaluate(args)
本地测试
在提交云端训练前,建议先在本地进行测试:
gcloud ai-platform local train \
--package-path=taxifaremodel \
--module-name=taxifaremodel.task \
-- \
--train_data_path=taxi-train.csv \
--eval_data_path=taxi-valid.csv \
--train_steps=1 \
--output_dir=taxi_trained
云端训练
确认本地测试通过后,提交云端训练任务:
OUTDIR="gs://${BUCKET}/taxifare/trained_small"
gsutil -m rm -rf ${OUTDIR} # 清空输出目录
gcloud ai-platform jobs submit training taxifare_$(date -u +%y%m%d_%H%M%S) \
--package-path=taxifaremodel \
--module-name=taxifaremodel.task \
--job-dir=gs://${BUCKET}/taxifare \
--python-version=3.5 \
--runtime-version=${TFVERSION} \
--region=${REGION} \
-- \
--train_data_path=gs://${BUCKET}/taxifare/smallinput/taxi-train.csv \
--eval_data_path=gs://${BUCKET}/taxifare/smallinput/taxi-valid.csv \
--train_steps=1000 \
--output_dir=${OUTDIR}
模型部署
训练完成后,将模型部署为生产级API服务:
- 首先检查导出的模型:
gsutil ls gs://${BUCKET}/taxifare/trained_small/export/exporter
- 创建模型和版本:
VERSION='v1'
gcloud ai-platform models create taxifare --region us-central1
gcloud ai-platform versions delete ${VERSION} --model taxifare --quiet
gcloud ai-platform versions create ${VERSION} --model taxifare \
--origin $(gsutil ls gs://${BUCKET}/taxifare/trained_small/export/exporter | tail -1) \
--python-version=3.5 \
--runtime-version ${TFVERSION}
在线预测
部署完成后,可以通过以下方式调用API:
- 使用gcloud命令行工具:
gcloud ai-platform predict \
--model taxifare \
--version ${VERSION} \
--json-instances input.json
- 直接调用REST API:
import requests
import json
data = {
"instances": [
{
"dayofweek": 3,
"hourofday": 15,
"pickuplon": -73.982155,
"pickuplat": 40.767937,
"dropofflon": -73.964630,
"dropofflat": 40.765602
}
]
}
response = requests.post(
f"https://ml.googleapis.com/v1/projects/{PROJECT}/models/taxifare/versions/{VERSION}:predict",
headers={"Authorization": f"Bearer $(gcloud auth print-access-token)"},
json=data
)
print(response.json())
最佳实践
- 版本控制:每次模型更新都创建新版本,便于回滚和AB测试
- 监控:设置监控告警,跟踪预测延迟和错误率
- 自动缩放:根据负载配置自动缩放,优化成本
- 日志记录:记录所有预测请求和结果,用于分析和调试
通过Google Cloud AI Platform,我们可以轻松将本地开发的模型迁移到云端,享受弹性计算资源和生产级部署能力,同时专注于模型开发而非基础设施管理。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考