基于Google Cloud AI Platform的机器学习模型训练与部署实战

基于Google Cloud AI Platform的机器学习模型训练与部署实战

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

本文将通过一个出租车费用预测模型的案例,详细介绍如何将本地开发的TensorFlow模型迁移到Google Cloud AI Platform上进行云端训练和部署。

为什么需要云端训练与部署?

在机器学习项目开发过程中,我们通常会先在本地进行模型开发和测试。但随着数据量增大和模型复杂度提高,本地环境往往会遇到以下问题:

  1. 计算资源不足,训练时间过长
  2. 难以进行大规模分布式训练
  3. 模型部署困难,缺乏生产级API服务

Google Cloud AI Platform提供了完整的解决方案,让我们可以:

  • 利用云端强大的计算资源加速训练
  • 轻松实现分布式训练
  • 一键部署生产级REST API服务
  • 无需管理基础设施

项目准备

1. 数据上传至Google Cloud Storage

云端服务无法访问本地文件,因此我们需要先将数据上传到Google Cloud Storage(GCS):

PROJECT = "your-project-name"  # 替换为你的项目ID
BUCKET = "your-bucket-name"   # 替换为你的存储桶名称
REGION = "us-central1"        # 选择AI Platform可用的区域
TFVERSION = "1.14"            # 使用的TensorFlow版本

使用gsutil命令行工具上传数据:

gcloud config set project ${PROJECT}
gsutil mb -l ${REGION} gs://${BUCKET}
gsutil -m cp *.csv gs://${BUCKET}/taxifare/smallinput/

2. 将代码打包为Python包

AI Platform训练作业需要将代码打包成Python包进行分发。Python包的基本结构是:

taxifaremodel/
    __init__.py    # 标识目录为Python包
    model.py       # 模型代码
    task.py        # 任务入口

创建包结构:

mkdir taxifaremodel
touch taxifaremodel/__init__.py

3. 编写模型代码

model.py中,我们需要实现以下核心功能:

  1. 数据输入管道
  2. 模型架构定义
  3. 训练评估逻辑
  4. 模型导出功能

关键点包括:

  • 支持从GCS读取数据
  • 将检查点文件写入GCS
  • 实现ServingInputReceiver用于模型部署

4. 编写任务入口

task.py是训练任务的入口文件,主要功能是:

  1. 解析命令行参数
  2. 调用模型训练函数
  3. 传递参数给训练流程
import argparse
from . import model

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data_path", required=True)
    parser.add_argument("--train_steps", type=int, default=1000)
    parser.add_argument("--eval_data_path", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--job-dir")  # AI Platform要求的参数
    
    args = parser.parse_args().__dict__
    model.train_and_evaluate(args)

本地测试

在提交云端训练前,建议先在本地进行测试:

gcloud ai-platform local train \
    --package-path=taxifaremodel \
    --module-name=taxifaremodel.task \
    -- \
    --train_data_path=taxi-train.csv \
    --eval_data_path=taxi-valid.csv \
    --train_steps=1 \
    --output_dir=taxi_trained

云端训练

确认本地测试通过后,提交云端训练任务:

OUTDIR="gs://${BUCKET}/taxifare/trained_small"

gsutil -m rm -rf ${OUTDIR}  # 清空输出目录

gcloud ai-platform jobs submit training taxifare_$(date -u +%y%m%d_%H%M%S) \
    --package-path=taxifaremodel \
    --module-name=taxifaremodel.task \
    --job-dir=gs://${BUCKET}/taxifare \
    --python-version=3.5 \
    --runtime-version=${TFVERSION} \
    --region=${REGION} \
    -- \
    --train_data_path=gs://${BUCKET}/taxifare/smallinput/taxi-train.csv \
    --eval_data_path=gs://${BUCKET}/taxifare/smallinput/taxi-valid.csv \
    --train_steps=1000 \
    --output_dir=${OUTDIR}

模型部署

训练完成后,将模型部署为生产级API服务:

  1. 首先检查导出的模型:
gsutil ls gs://${BUCKET}/taxifare/trained_small/export/exporter
  1. 创建模型和版本:
VERSION='v1'
gcloud ai-platform models create taxifare --region us-central1
gcloud ai-platform versions delete ${VERSION} --model taxifare --quiet
gcloud ai-platform versions create ${VERSION} --model taxifare \
    --origin $(gsutil ls gs://${BUCKET}/taxifare/trained_small/export/exporter | tail -1) \
    --python-version=3.5 \
    --runtime-version ${TFVERSION}

在线预测

部署完成后,可以通过以下方式调用API:

  1. 使用gcloud命令行工具:
gcloud ai-platform predict \
    --model taxifare \
    --version ${VERSION} \
    --json-instances input.json
  1. 直接调用REST API:
import requests
import json

data = {
  "instances": [
    {
      "dayofweek": 3,
      "hourofday": 15,
      "pickuplon": -73.982155,
      "pickuplat": 40.767937,
      "dropofflon": -73.964630,
      "dropofflat": 40.765602
    }
  ]
}

response = requests.post(
    f"https://ml.googleapis.com/v1/projects/{PROJECT}/models/taxifare/versions/{VERSION}:predict",
    headers={"Authorization": f"Bearer $(gcloud auth print-access-token)"},
    json=data
)

print(response.json())

最佳实践

  1. 版本控制:每次模型更新都创建新版本,便于回滚和AB测试
  2. 监控:设置监控告警,跟踪预测延迟和错误率
  3. 自动缩放:根据负载配置自动缩放,优化成本
  4. 日志记录:记录所有预测请求和结果,用于分析和调试

通过Google Cloud AI Platform,我们可以轻松将本地开发的模型迁移到云端,享受弹性计算资源和生产级部署能力,同时专注于模型开发而非基础设施管理。

training-data-analyst Labs and demos for courses for GCP Training (http://cloud.google.com/training). training-data-analyst 项目地址: https://gitcode.com/gh_mirrors/tr/training-data-analyst

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

窦岑品

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值