34、实时机器学习:从查询到模型评估与持续训练

实时机器学习:从查询到模型评估与持续训练

1. 从 Cloud Bigtable 查询数据

使用 BigQuery 作为数据接收器时,即使在数据流式传输过程中,也能使用 SQL 进行分析。而 Cloud Bigtable 虽然也提供流式分析,但不支持 SQL。由于它是 NoSQL 存储,通常需要手动编写客户端应用程序。不过,我们可以使用 HBase 命令行 shell 来查询表内容。

例如,通过扫描表并将结果限制为一行,可获取数据库中的最新行:

scan 'predictions', {'LIMIT' => 1}

执行上述命令后,输出如下:

hbase(main):006:0> scan 'predictions', {'LIMIT' => 1}
ROW                              COLUMN+CELL
ABE#ATL#DL#9223370608969975807  column=FL:AIRLINE_ID, timestamp=1427891940, 
value=19790                                   
ABE#ATL#DL#9223370608969975807  column=FL:ARR_AIRPORT_LAT, timestamp=1427891940, 
value=33.63666667                        
ABE#ATL#DL#9223370608969975807  column=FL:ARR_AIRPORT_LON, timestamp=1427891940, 
value=-84.42777778  
…

由于行按行键升序排序,所以结果按出发机场、到达机场和反向时间戳排列。这就是为什么我们能获取以字母 A 开头的两个机场之间的最新航班信息。命令行 shell 会为每个单元格输出一行,因此即使这些行都属于同一行(注意行键相同),我们也会得到多行输出。

行键的设计优势在于能够获取一对机场之间的最后几趟航班信息。例如,要获取美国航空公司(AA)从芝加哥奥黑尔机场(ORD)到洛杉矶(LAX)的最新两趟航班的准时性和事件列信息,可使用以下命令:

scan 'predictions', {STARTROW => 'ORD#LAX#AA', ENDROW => 'ORD#LAX#AB', 
                     COLUMN => ['FL:ontime','FL:EVENT'], LIMIT => 2}

执行结果如下:

ROW                                          COLUMN+CELL
 ORD#LAX#AA#9223370608929475807              column=FL:EVENT, timestamp=14279262
                                             00, value=wheelsoff
 ORD#LAX#AA#9223370608929475807              column=FL:ontime, timestamp=1427926
                                             200, value=0.73
 ORD#LAX#AA#9223370608939975807              column=FL:EVENT, timestamp=14279320
                                             80, value=arrived
 ORD#LAX#AA#9223370608939975807              column=FL:ontime, timestamp=1427932
                                             080, value=1.00

注意,到达事件显示的是实际准点率(1.00),而起飞事件显示的是预测准点到达概率(0.73)。我们可以编写一个使用 Cloud Bigtable 客户端 API 的应用程序,在 Cloud Bigtable 表上执行此类查询,并将结果展示给最终用户。

2. 评估模型性能

最终模型的性能只能通过真正独立的数据进行评估。由于在评估不同模型和进行超参数调优时使用了“测试”数据,因此不能使用 2015 年的航班数据来评估模型性能。幸运的是,从开始到现在已经过去了足够长的时间,美国运输统计局(BTS)网站上已有 2016 年的数据。我们可以使用 2016 年的数据来评估机器学习模型。

当然,环境已经发生了变化,2016 年运营的航空公司列表可能与 2015 年不同,航空公司调度员可能也改变了航班调度方式,经济状况的变化可能导致航班更满(从而导致登机时间更长)。尽管如此,使用 2016 年的数据进行评估是正确的做法,因为在现实世界中,我们可能在 2016 年仍在使用 2015 年的模型为用户提供服务,我们需要了解这些预测的效果如何。

3. 持续训练的必要性

准备 2016 年的数据需要重复处理 2015 年数据时所执行的步骤:
1. 修改 02_ingest 中的 ingest_2015.sh,以下载和清理 2016 年的文件,并将其上传到云端。确保上传到不同的存储桶或输出目录,以避免 2015 年和 2016 年的数据集混合。
2. 修改 04_streaming/simulate/df06.py,使其指向 2016 年的 Cloud Storage 存储桶,并将结果保存到不同的 BigQuery 数据集(例如 flights2016)。此脚本会进行时区校正,并向数据集中添加机场位置信息(纬度、经度)。

在执行这些步骤时,发现第二步(df06.py)失败了,原因是该脚本使用的 airports.csv 文件在 2016 年不完整。新建了一些机场,部分机场位置也发生了变化,因此 2016 年有一些独特的机场代码在 2015 年的文件中不存在。虽然可以获取 2016 年对应的 airports.csv 文件,但这并不能解决一个更大的问题。在机器学习模型中,我们通过创建出发和到达机场的嵌入来使用机场位置信息,对于新机场,这些特征将无法正常工作。

在现实世界中,尤其是处理与人类和人类制品(如客户、机场、产品等)相关的数据时,不太可能只训练一次模型并一直使用下去。相反,模型需要使用最新数据进行持续训练。持续训练是机器学习中不可或缺的一部分,因此 Cloud ML Engine 强调易于操作和版本控制,这是需要自动化的工作流程。

这里所说的是模型的持续训练,而不是从头开始重新训练。在训练模型时,我们会保存检查点。要使用新数据训练模型,我们可以从这样的检查点模型开始,加载它,然后运行几个批次的数据,从而调整权重。这使模型能够在不丢失已积累知识的情况下慢慢适应新数据。我们还可以替换检查点图中的节点,或者冻结某些层,只训练其他层(例如机场的嵌入层)。如果已经进行了学习率衰减等操作,应继续以较低的学习率进行训练,而不是以最高学习率重新开始训练。Cloud ML Engine 和 TensorFlow 就是为此而设计的。

目前,我们可以简单地修改查找机场代码的代码,以优雅地处理错误,并为机器学习代码提供合理的估算值:

def airport_timezone(airport_id):
       if airport_id in airport_timezones_dict:
          return airport_timezones_dict[airport_id]
       else:
          return ('37.52', '-92.17', u'America/Chicago')

如果在字典中未找到机场,机场位置将被假定为(37.52, –92.17),这对应于美国人口普查局估计的美国人口中心位置。

4. 评估管道

创建 2016 年的事件数据后,我们可以在 Cloud Dataflow 中创建一个评估管道,该管道结合了创建训练数据集和进行实时预测的管道。具体步骤和相应代码如下表所示:
| 步骤 | 操作 | 代码 |
| — | — | — |
| 1 | 输入查询 | SELECT EVENT_DATA FROM flights2016.simevents WHERE (EVENT = 'wheelsoff' OR EVENT = 'arrived') |
| 2 | 读取到航班的 PCollection 中 | CreateTrainingDataset.readFlights |
| 3 | 获取平均出发延误时间 | AddRealtimePrediction.readAverageDepartureDelay |
| 4 | 应用移动平均窗口 | CreateTrainingDataset.applyTimeWindow |
| 5 | 计算平均到达延误时间 | CreateTrainingDataset.computeAverageArrivalDelay |
| 6 | 将延误信息添加到航班数据中 | CreateTrainingDataset.addDelayInformation |
| 7 | 仅保留已到达的航班,因为只有这些航班才有真实的准点率 | Flight f = c.element(); if (f.getField(INPUTCOLS.EVENT).equals("arrived")) { c.output(f); } |
| 8 | 调用航班机器学习服务 | InputOutput.addPredictionInBatches (part of real-time pipeline) |
| 9 | 输出结果 | FlightPred fp = c.element(); String csv = fp.ontime + "," + fp.flight.toTrainingCsv(); c.output(csv); |

从这个批处理管道调用航班机器学习服务时,需要将在线预测的配额提高到 10,000 个请求/秒。

graph LR
    A[输入查询] --> B[读取到航班的 PCollection 中]
    B --> C[获取平均出发延误时间]
    C --> D[应用移动平均窗口]
    D --> E[计算平均到达延误时间]
    E --> F[将延误信息添加到航班数据中]
    F --> G[仅保留已到达的航班]
    G --> H[调用航班机器学习服务]
    H --> I[输出结果]
5. 评估模型性能

评估管道会将评估数据集写入 Cloud Storage 上的文本文件。为了方便分析,我们可以将整个数据集批量加载到 BigQuery 中作为一个表:

bq load -F , flights2016.eval \
   gs://cloud-training-demos-ml/flights/chapter10/eval/*.csv \
   evaldata.json

有了这个 BigQuery 表,我们就可以快速轻松地对模型性能进行分析。例如,我们可以在 BigQuery 控制台中输入以下内容来计算均方根误差(RMSE):

SELECT
  SQRT(AVG((PRED-ONTIME)*(PRED-ONTIME)))
FROM
  flights2016.eval

计算结果显示 RMSE 为 0.24。回想一下,在训练过程中的验证数据集上,RMSE 为 0.187。对于一个来自完全不同时间段(有新机场,且可能有不同的空中交通管制和航空公司调度策略)的完全独立的数据集来说,这个结果相当合理。

6. 边际分布

我们可以使用 Cloud Datalab 更深入地分析模型的性能特征。通过 BigQuery 按预测概率和实际准点情况对统计数据进行聚合:

sql = """
SELECT
  pred, ontime, count(pred) as num_pred, count(ontime) as num_ontime
FROM
  flights2016.eval
GROUP BY pred, ontime
"""
df = bq.Query(sql).execute().result().to_dataframe()

根据上述代码得到的 Pandas 数据框,包含了每对预测概率(四舍五入到小数点后两位)和实际准点到达情况的总预测数和实际准点到达数。

我们可以通过对 Pandas 数据框按实际准点到达情况进行切片,并按预测概率绘制预测数量,来绘制边际分布。结果表明,绝大多数准点到达的航班都被正确预测,预测概率小于 0.4 的准点到达航班数量相当少。同样,另一个边际分布也显示,在绝大多数情况下,晚点到达的航班也被正确预测。

这种使用 BigQuery 聚合大型数据集,以及使用 Cloud Datalab 上的 Pandas 和 matplotlib 或 seaborn 可视化结果的模式非常强大。matplotlib 和 seaborn 有强大的绘图和可视化工具,但无法处理非常大的数据集。BigQuery 能够查询大型数据集并提炼出关键信息,但它不是一个可视化引擎。Cloud Datalab 能够将 BigQuery 结果集转换为 Pandas 数据框,充当了强大的后端数据处理引擎(BigQuery)和强大的可视化工具(matplotlib/seaborn)之间的桥梁。

将查询委托给无服务器后端的用途不仅限于数据科学。例如,结合使用 Cloud Datalab 和 Stackdriver 可以将强大的绘图功能与大规模日志提取和查询相结合。

7. 检查模型行为

我们还可以检查模型对数据的理解是否合理。例如,通过以下查询可以得到模型预测与出发延误之间的关系:

SELECT
  pred, AVG(dep_delay) as dep_delay
FROM
  flights2016.eval
GROUP BY pred

将查询结果绘制成图后,得到的图形非常平滑且合理。延误 30 分钟或更长时间的航班的预测概率接近 0,而提前出发(即出发延误为负数)的航班的预测概率超过 0.8,且过渡平滑渐进。

我们还可以分析出发延误与模型误差(定义为航班准点到达但预测概率小于 0.5 的情况)之间的关系:

SELECT
  pred, AVG(dep_delay) as dep_delay
FROM
  flights2016.eval
WHERE (ontime > 0.5 and pred <= 0.5) or (ontime < 0.5 and pred > 0.5)
GROUP BY pred

将查询结果绘制成图后,可以看出模型的许多误差出现在晚点出发但仍准点到达的航班上。

我们可以类似地按每个输入变量(如滑行距离、出发机场等)分析模型的性能。为了简洁起见,我们仅选择一个分类变量——航空公司来评估模型。

查询按分类变量进行分组和聚合:

SELECT
  carrier, pred, AVG(dep_delay) as dep_delay
FROM
  flights2016.eval
WHERE (ontime > 0.5 and pred <= 0.5) or (ontime < 0.5 and pred > 0.5)
GROUP BY pred, carrier

Pandas 绘图代码会提取相应的组,并为每个航空公司绘制单独的线条:

df = df.sort_values(by='pred')
fig, ax = plt.subplots(figsize=(8,12))
for label, dfg in df.groupby('carrier'):
    dfg.plot(x='pred', y='dep_delay', ax=ax, label=label)
plt.legend();
8. 识别行为变化

从绘制的图形中可以看出,对于误判为晚点但实际准点到达的航班(假阳性),由 OO(SkyWest 航空公司)运营的航班的平均出发延误最低。相反,对于漏判(预测准点但实际晚点)的航班,由 WN(西南航空公司)运营的航班的平均出发延误最高。

这些观察结果促使我们研究这两家航空公司在 2015 年至 2016 年运营方式的变化,并确保在使用新数据重新训练模型时捕捉到这些差异。例如,Skywest 可能改变了其调度算法,在名义航班时间上增加了额外的时间,导致原本在 2015 年晚点的航班在 2016 年准点到达。

我们可以进行抽样检查来验证这一点。首先,找出 OO 运营的最常见航班:

SELECT
  origin,
  dest,
  COUNT(*) as num_flights
FROM
  flights2016.eval
WHERE
  carrier = 'OO'
GROUP BY
  origin,
  dest
order by num_flights desc limit 10

然后,查看该航班(ORD - MKE)在 2015 年和 2016 年的统计数据:

SELECT
  APPROX_QUANTILES(TIMESTAMP_DIFF(CRS_ARR_TIME, CRS_DEP_TIME, SECOND), 5)
FROM
  flights2016.simevents
WHERE
  carrier = 'OO' and origin = 'ORD' and dest = 'MKE'

结果显示,2016 年从芝加哥(ORD)到密尔沃基(MKE)的航班计划时间的分位数为 2580、2820、2880、2940、3180,而 2015 年对应的分位数为 2400、2760、2880、3060、3300。这表明航空公司在原本时间紧张的航班上增加了时间,而在较长航班上减少了时间。这种系统底层统计数据的变化说明了持续进行机器学习训练的必要性。

9. 总结

整个过程完成了从数据摄取到机器学习的端到端流程。我们将训练好的机器学习模型部署为微服务,并将其嵌入到实时流式 Cloud Dataflow 管道中。为此,我们需要将航班对象转换为 JSON 请求,并将 JSON 响应转换为 FlightPred 对象以供管道使用。同时,我们发现逐个发送航班请求在网络、成本和时间方面可能代价高昂,因此在 Cloud Dataflow 管道中对机器学习服务的请求进行了批量处理。

管道从 Cloud Pub/Sub 读取数据,将每个主题的 PCollection 扁平化到一个公共的 PCollection 中,并使用与训练时相同的代码进行处理。在服务和训练中使用相同的代码有助于减少训练 - 服务偏差。我们还使用水印和触发器来更精细地控制如何处理延迟到达、乱序的记录。

我们还探索了其他可能的流式接收器以及如何在它们之间进行选择。作为示例,我们实现了一个 Cloud Bigtable 接收器,适用于需要高吞吐量和低延迟的情况。我们设计了合适的行键,以便在实现快速读取的同时并行化写入操作。

最后,我们研究了如何评估模型性能,并识别由于所建模系统中关键参与者的行为变化导致模型不再有效的情况。通过发现 Skywest 航空公司在 2015 年至 2016 年航班调度的变化,我们深刻认识到持续进行机器学习训练的重要性。在构建了端到端系统后,我们需要不断改进它,并使用最新数据进行更新。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值