航班实时机器学习预测:从批处理到流式处理
1. 与航班机器学习服务交互
1.1 发送 POST 请求并解析响应
在能够创建和解析 JSON 字节后,我们可以与航班机器学习服务进行交互。该服务的部署 URL 如下:
String endpoint = "https://ml.googleapis.com/v1/projects/"
+ String.format("%s/models/%s/versions/%s:predict",
PROJECT, MODEL, VERSION);
GenericUrl url = new GenericUrl(endpoint);
以下是向 Google Cloud Platform 进行身份验证、发送请求并获取响应的代码:
GoogleCredential credential = GoogleCredential.getApplicationDefault();
HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
HttpContent content = new ByteArrayContent("application/json", json.getBytes());
HttpRequest request = requestFactory.buildRequest("POST", url, content);
request.setUnsuccessfulResponseHandler(
new HttpBackOffUnsuccessfulResponseHandler(new ExponentialBackOff()));
request.setReadTimeout(5 * 60 * 1000);
String response = request.execute().parseAsString();
这里添加了两个容错机制:如果出现网络中断或临时故障(非 2xx 的 HTTP 错误),会使用
ExponentialBackoff
类增加重试间隔来重试请求。由于流量突然激增可能导致额外服务器上线时出现 100 - 150 秒的延迟,因此设置了较长的超时时间。
1.2 预测服务客户端
上述代码封装在
FlightsMLService
辅助类中,用于获取航班准点的概率:
public static double predictOntimeProbability(Flight f, double defaultValue)
throws IOException, GeneralSecurityException {
if (f.isNotCancelled() && f.isNotDiverted()) {
Request request = new Request();
Instance instance = new Instance(f);
request.instances.add(instance);
Response resp = sendRequest(request);
double[] result = resp.getOntimeProbability(defaultValue);
if (result.length > 0) {
return result[0];
} else {
return defaultValue;
}
}
return defaultValue;
}
需要检查航班是否取消或改道,因为神经网络仅在有实际到达信息的航班上进行训练。
2. 向航班信息添加预测
2.1 批处理输入和输出
在进行实时处理之前,先编写批处理管道实现相同功能。定义一个 Java 接口
InputOutput
来封装输入和输出:
interface InputOutput extends Serializable {
public PCollection<Flight> readFlights(Pipeline p, MyOptions options);
public void writeFlights(PCollection<Flight> outFlights, MyOptions options);
}
批处理数据从 BigQuery 表读取,并将结果写入 Google Cloud Storage 的文本文件。关键步骤如下表所示:
| 步骤 | 功能 | 代码 |
| — | — | — |
| 1 | 创建查询 |
String query = "SELECT EVENT_DATA FROM flights.simevents WHERE STRING(FL_DATE) = '2015-01-04' AND (EVENT = 'wheelsoff' OR EVENT = 'arrived') ";
|
| 2 | 读取行 |
BigQueryIO.Read.fromQuery(query)
|
| 3 | 解析为
Flight
对象 |
TableRow row = c.element(); String line = (String) row.getOrDefault("EVENT_DATA", ""); Flight f = Flight.fromCsv(line); if (f != null) { c.outputWithTimestamp(f, f.getEventTimestamp()); }
|
创建
FlightPred
对象来表示需要写入的数据:
@DefaultCoder(AvroCoder.class)
public class FlightPred {
Flight flight;
double ontime;
// constructor, etc.
}
ontime
字段根据不同情况存储不同值:
- 如果收到的事件是
wheelsoff
,则存储预测的准点概率。
- 如果收到的事件是
arrived
,则存储实际的准点性能(0 或 1)。
- 如果航班取消或改道,则存储
null
。
写入代码如下:
PCollection<FlightPred> prds = addPrediction(outFlights);
PCollection<String> lines = predToCsv(prds);
lines.apply("Write", TextIO.Write.to(options.getOutput() + "flightPreds").withSuffix(".csv"));
addPrediction()
方法的
processElement()
方法如下:
Flight f = c.element();
double ontime = -5;
if (f.isNotCancelled() && f.isNotDiverted()) {
if (f.getField(INPUTCOLS.EVENT).equals("arrived")) {
ontime = f.getFieldAsFloat(INPUTCOLS.ARR_DELAY, 0) < 15 ? 1 : 0;
} else {
ontime = FlightsMLService.predictOntimeProbability(f, -5.0);
}
}
c.output(new FlightPred(f, ontime));
将
FlightPred
转换为 CSV 文件时,会将无效的负值替换为
null
:
FlightPred pred = c.element();
String csv = String.join(",", pred.flight.getFields());
if (pred.ontime >= 0) {
csv = csv + "," + new DecimalFormat("0.00").format(pred.ontime);
} else {
csv = csv + ",";
}
c.output(csv);
2.2 数据处理管道
实现管道的中间部分,除了处理出发延迟外,与
CreateTrainingDataset
中的代码相同:
Pipeline p = Pipeline.create(options);
InputOutput io = new BatchInputOutput();
PCollection<Flight> allFlights = io.readFlights(p, options);
PCollectionView<Map<String, Double>> avgDepDelay = readAverageDepartureDelay(p, options.getDelayPath());
PCollection<Flight> hourlyFlights = CreateTrainingDataset.applyTimeWindow(allFlights);
PCollection<KV<String, Double>> avgArrDelay = CreateTrainingDataset.computeAverageArrivalDelay(hourlyFlights);
hourlyFlights = CreateTrainingDataset.addDelayInformation(hourlyFlights, avgDepDelay, avgArrDelay, averagingFrequency);
io.writeFlights(hourlyFlights, options);
PipelineResult result = p.run();
result.waitUntilFinish();
出发延迟直接从训练期间写出的全局平均值(训练数据集上)中读取:
private static PCollectionView<Map<String, Double>> readAverageDepartureDelay(Pipeline p, String path) {
return p.apply("Read delays.csv", TextIO.Read.from(path))
.apply("Parse delays.csv",
ParDo.of(new DoFn<String, KV<String, Double>>() {
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
String line = c.element();
String[] fields = line.split(",");
c.output(KV.of(fields[0], Double.parseDouble(fields[1])));
}
}))
.apply("toView", View.asMap());
}
2.3 识别低效问题
在 Google Cloud Platform 的 Cloud Dataflow 上执行批处理管道时,发现推理一直失败。通过 Google Cloud Platform 网络控制台的 API 仪表板监控 API 使用情况,发现前几个请求以 HTTP 响应代码 200 成功,但其余请求以响应代码 429(“请求过多”)失败。可以请求增加配额,但在此之前,应优化管道以减少对服务的请求。
2.4 批量请求
不需要逐个向服务发送航班信息,可以批量请求。有两种批量请求的方式:基于计数(累积航班直到达到阈值,如 100 个航班,然后使用这些航班调用服务)或基于时间(在固定时间跨度内累积航班,如两分钟,然后使用该时间段内累积的所有航班信息调用服务)。
在 Cloud Dataflow 管道中基于计数批量请求的代码如下:
.apply(Window.into(new GlobalWindows()).triggering(
Repeatedly.forever(AfterFirst.of(
AfterPane.elementCountAtLeast(100),
AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.standardMinutes(1))))
但由于管道使用了滑动窗口来计算平均到达延迟,应用
GlobalWindow
会导致对象重新窗口化,产生不必要的副作用。因此,需要在滑动窗口的上下文中进行批量处理。
解决方案是为每个航班发出一个键值对,键反映航班的批次号:
Flight f = c.element();
String key = "batch=" + System.identityHashCode(f) % NUM_BATCHES;
c.output(KV.of(key, f));
然后使用
GroupByKey
进行分组:
PCollection<FlightPred> lines = outFlights
.apply("Batch->Flight", ParDo.of(...))
.apply("CreateBatches", GroupByKey.<String, Flight> create())
GroupByKey
输出的
Iterable<Flight>
可以用于批量预测:
double[] batchPredict(Iterable<Flight> flights, float defaultValue) throws IOException, GeneralSecurityException {
Request request = new Request();
for (Flight f : flights) {
request.instances.add(new Instance(f));
}
Response resp = sendRequest(request);
double[] result = resp.getOntimeProbability(defaultValue);
return result;
}
通过这种更改,请求数量大幅下降,并且在配额范围内。
以下是批量请求优化的流程图:
graph TD;
A[开始] --> B[获取航班数据];
B --> C[发出键值对];
C --> D[GroupByKey分组];
D --> E[批量预测];
E --> F[结束];
3. 流式处理管道
3.1 合并 PCollections
当处理流式数据时,数据从 Cloud Pub/Sub 读取并流入 BigQuery。之前从文本文件读取数据时,
arrived
和
wheelsoff
事件来自同一个文件,但模拟代码将这些事件流式传输到 Cloud Pub/Sub 的两个不同主题。为了重用数据处理代码,需要合并这两个
PCollection
。可以使用
Flatten
转换来实现:
// 空集合开始
PCollectionList<Flight> pcs = PCollectionList.empty(p);
// 从两个主题读取航班
for (String eventType : new String[]{"wheelsoff", "arrived"}){
String topic = "projects/" + options.getProject() + "/topics/" + eventType;
PCollection<Flight> flights = p.apply(eventType + ":read",
PubsubIO.<String> read().topic(topic)
.withCoder(StringUtf8Coder.of())
.timestampLabel("EventTimeStamp"))
.apply(eventType + ":parse",
ParDo.of(new DoFn<String, Flight>() {
@ProcessElement
public void processElement(ProcessContext c) throws Exception {
String line = c.element();
Flight f = Flight.fromCsv(line);
if (f != null) {
c.output(f);
}
}
}));
pcs = pcs.and(flights);
}
// 合并集合
return pcs.apply(Flatten.<Flight>pCollections());
这里需要注意
PubsubIO
读取时的
timestampLabel
。如果不指定,消息将被分配插入到 Pub/Sub 时的时间戳。为了使用航班的实际时间作为时间戳,需要在发布消息时指定
EventTimeStamp
属性:
def publish(topics, allevents, notify_time):
timestamp = notify_time.strftime(RFC3339_TIME_FORMAT)
for key in topics: # 'departed', 'arrived', etc.
topic = topics[key]
events = allevents[key]
with topic.batch() as batch:
for event_data in events:
batch.publish(event_data, EventTimeStamp=timestamp)
3.2 流式输出到 BigQuery
将
FlightPred
信息流式传输到 BigQuery 需要创建
TableRow
并写入 BigQuery 表:
String outputTable = options.getProject() + ':' + BQ_TABLE_NAME;
TableSchema schema = new TableSchema().setFields(getTableFields());
PCollection<FlightPred> preds = addPredictionInBatches(outFlights);
PCollection<TableRow> rows = toTableRows(preds);
rows.apply("flights:write_toBQ", BigQueryIO.Write.to(outputTable)
.withSchema(schema));
3.3 执行流式处理管道
编写好管道代码后,可以启动第 4 章的模拟程序将记录流式传输到 Cloud Pub/Sub:
cd ../04_streaming/simulate
python simulate.py --startTime "2015-04-01 00:00:00 UTC" --endTime "2015-04-03 00:00:00 UTC" --speedFactor 60
然后启动 Cloud Dataflow 管道来消费这些记录,调用航班机器学习服务,并将结果记录写入 BigQuery:
mvn compile exec:java \
-Dexec.mainClass=com.google.cloud.training.flights.AddRealtimePrediction \
-Dexec.args="--realtime --speedupFactor=60 --maxNumWorkers=10 --autoscalingAlgorithm=THROUGHPUT_BASED"
通过
realtime
选项可以在批处理和实时处理的
InputOutput
实现之间切换。
3.4 处理延迟和乱序记录
在实际应用中,航班记录可能会延迟或乱序到达。可以通过在模拟程序使用的 BigQuery SQL 语句中添加随机延迟来模拟这种情况:
SELECT
EVENT,
NOTIFY_TIME AS ORIGINAL_NOTIFY_TIME,
TIMESTAMP_ADD(NOTIFY_TIME, INTERVAL CAST (0.5 + RAND()*120 AS INT64) SECOND) AS NOTIFY_TIME,
EVENT_DATA
FROM
`flights.simevents`
这里使用
RAND()
函数生成 0 到 1 之间的随机数,乘以 120 得到 0 到 2 分钟的延迟。
还可以模拟其他延迟分布:
-
均匀分布延迟
:如果要模拟 90 到 120 秒的延迟,可以使用
CAST(90.5 + RAND()*30 AS INT64)
。
-
指数分布延迟
:使用
CAST(-LN(RAND()*0.99 + 0.01)*30 + 90.5 AS INT64)
来模拟指数分布延迟。
-
正态分布延迟
:在 BigQuery 中使用用户定义函数(UDF)来生成正态分布的随机变量:
js = """
var u = 1 - Math.random();
var v = 1 - Math.random();
var f = Math.sqrt(-2 * Math.log(u)) * Math.cos(2*Math.PI*v);
f = f * sigma + mu;
if (f < 0)
return 0;
else
return f;
""".replace('\n', ' ')
然后在 SQL 中创建临时 UDF 并使用:
sql = """
CREATE TEMPORARY FUNCTION
trunc_rand_normal(x FLOAT64, mu FLOAT64, sigma FLOAT64)
RETURNS FLOAT64
LANGUAGE js AS "{}";
SELECT
trunc_rand_normal(ARR_DELAY, 90, 15) AS JITTER
FROM
...
""".format(js).replace('\n', ' ')
为了实验不同类型的延迟,可以修改模拟代码:
jitter = 'CAST (-LN(RAND()*0.99 + 0.01)*30 + 90.5 AS INT64)'
# 运行查询以提取模拟事件
querystr = """\
SELECT
EVENT,
TIMESTAMP_ADD(NOTIFY_TIME, INTERVAL {} SECOND) AS NOTIFY_TIME,
EVENT_DATA
FROM
`cloud-training-demos.flights.simevents`
WHERE
NOTIFY_TIME >= TIMESTAMP('{}')
AND NOTIFY_TIME < TIMESTAMP('{}')
ORDER BY
NOTIFY_TIME
""".format(jitter, start_time, end_time)
以下是流式处理管道的流程图:
graph TD;
A[开始] --> B[从 Cloud Pub/Sub 读取数据];
B --> C[合并 PCollections];
C --> D[添加预测信息];
D --> E[转换为 TableRow];
E --> F[写入 BigQuery];
F --> G[结束];
综上所述,通过上述步骤,我们可以从批处理过渡到流式处理,实现航班实时机器学习预测,并处理可能出现的延迟和乱序记录。在实际应用中,可以根据具体需求调整代码和参数,以达到最佳性能和准确性。
超级会员免费看
8

被折叠的 条评论
为什么被折叠?



