32、航班实时机器学习预测:从批处理到流式处理

航班实时机器学习预测:从批处理到流式处理

1. 与航班机器学习服务交互

1.1 发送 POST 请求并解析响应

在能够创建和解析 JSON 字节后,我们可以与航班机器学习服务进行交互。该服务的部署 URL 如下:

String endpoint = "https://ml.googleapis.com/v1/projects/"
       + String.format("%s/models/%s/versions/%s:predict",
                        PROJECT, MODEL, VERSION);
GenericUrl url = new GenericUrl(endpoint);

以下是向 Google Cloud Platform 进行身份验证、发送请求并获取响应的代码:

GoogleCredential credential = GoogleCredential.getApplicationDefault();
HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
HttpContent content = new ByteArrayContent("application/json", json.getBytes());
HttpRequest request = requestFactory.buildRequest("POST", url, content);
request.setUnsuccessfulResponseHandler(
    new HttpBackOffUnsuccessfulResponseHandler(new ExponentialBackOff()));
request.setReadTimeout(5 * 60 * 1000);
String response = request.execute().parseAsString();

这里添加了两个容错机制:如果出现网络中断或临时故障(非 2xx 的 HTTP 错误),会使用 ExponentialBackoff 类增加重试间隔来重试请求。由于流量突然激增可能导致额外服务器上线时出现 100 - 150 秒的延迟,因此设置了较长的超时时间。

1.2 预测服务客户端

上述代码封装在 FlightsMLService 辅助类中,用于获取航班准点的概率:

public static double predictOntimeProbability(Flight f, double defaultValue) 
throws IOException, GeneralSecurityException {
    if (f.isNotCancelled() && f.isNotDiverted()) {
        Request request = new Request();
        Instance instance = new Instance(f);
        request.instances.add(instance);
        Response resp = sendRequest(request);
        double[] result = resp.getOntimeProbability(defaultValue);
        if (result.length > 0) {
            return result[0];
        } else {
            return defaultValue;
        }
    }
    return defaultValue;
}

需要检查航班是否取消或改道,因为神经网络仅在有实际到达信息的航班上进行训练。

2. 向航班信息添加预测

2.1 批处理输入和输出

在进行实时处理之前,先编写批处理管道实现相同功能。定义一个 Java 接口 InputOutput 来封装输入和输出:

interface InputOutput extends Serializable {
    public PCollection<Flight> readFlights(Pipeline p, MyOptions options);
    public void writeFlights(PCollection<Flight> outFlights, MyOptions options);
}

批处理数据从 BigQuery 表读取,并将结果写入 Google Cloud Storage 的文本文件。关键步骤如下表所示:
| 步骤 | 功能 | 代码 |
| — | — | — |
| 1 | 创建查询 | String query = "SELECT EVENT_DATA FROM flights.simevents WHERE STRING(FL_DATE) = '2015-01-04' AND (EVENT = 'wheelsoff' OR EVENT = 'arrived') "; |
| 2 | 读取行 | BigQueryIO.Read.fromQuery(query) |
| 3 | 解析为 Flight 对象 | TableRow row = c.element(); String line = (String) row.getOrDefault("EVENT_DATA", ""); Flight f = Flight.fromCsv(line); if (f != null) { c.outputWithTimestamp(f, f.getEventTimestamp()); } |

创建 FlightPred 对象来表示需要写入的数据:

@DefaultCoder(AvroCoder.class)
public class FlightPred {
    Flight flight;
    double ontime;
    // constructor, etc.
}

ontime 字段根据不同情况存储不同值:
- 如果收到的事件是 wheelsoff ,则存储预测的准点概率。
- 如果收到的事件是 arrived ,则存储实际的准点性能(0 或 1)。
- 如果航班取消或改道,则存储 null

写入代码如下:

PCollection<FlightPred> prds = addPrediction(outFlights);
PCollection<String> lines = predToCsv(prds);
lines.apply("Write", TextIO.Write.to(options.getOutput() + "flightPreds").withSuffix(".csv"));

addPrediction() 方法的 processElement() 方法如下:

Flight f = c.element();
double ontime = -5;
if (f.isNotCancelled() && f.isNotDiverted()) {
    if (f.getField(INPUTCOLS.EVENT).equals("arrived")) {
        ontime = f.getFieldAsFloat(INPUTCOLS.ARR_DELAY, 0) < 15 ? 1 : 0;
    } else {
        ontime = FlightsMLService.predictOntimeProbability(f, -5.0);
    }
}
c.output(new FlightPred(f, ontime));

FlightPred 转换为 CSV 文件时,会将无效的负值替换为 null

FlightPred pred = c.element();
String csv = String.join(",", pred.flight.getFields());
if (pred.ontime >= 0) {
    csv = csv + "," + new DecimalFormat("0.00").format(pred.ontime);
} else {
    csv = csv + ",";
}
c.output(csv);

2.2 数据处理管道

实现管道的中间部分,除了处理出发延迟外,与 CreateTrainingDataset 中的代码相同:

Pipeline p = Pipeline.create(options);
InputOutput io = new BatchInputOutput();
PCollection<Flight> allFlights = io.readFlights(p, options);
PCollectionView<Map<String, Double>> avgDepDelay = readAverageDepartureDelay(p, options.getDelayPath());
PCollection<Flight> hourlyFlights = CreateTrainingDataset.applyTimeWindow(allFlights);
PCollection<KV<String, Double>> avgArrDelay = CreateTrainingDataset.computeAverageArrivalDelay(hourlyFlights);
hourlyFlights = CreateTrainingDataset.addDelayInformation(hourlyFlights, avgDepDelay, avgArrDelay, averagingFrequency);
io.writeFlights(hourlyFlights, options);
PipelineResult result = p.run();
result.waitUntilFinish();

出发延迟直接从训练期间写出的全局平均值(训练数据集上)中读取:

private static PCollectionView<Map<String, Double>> readAverageDepartureDelay(Pipeline p, String path) {
    return p.apply("Read delays.csv", TextIO.Read.from(path))
           .apply("Parse delays.csv",
                ParDo.of(new DoFn<String, KV<String, Double>>() {
                    @ProcessElement
                    public void processElement(ProcessContext c) throws Exception {
                        String line = c.element();
                        String[] fields = line.split(",");
                        c.output(KV.of(fields[0], Double.parseDouble(fields[1])));
                    }
                }))
           .apply("toView", View.asMap());
}

2.3 识别低效问题

在 Google Cloud Platform 的 Cloud Dataflow 上执行批处理管道时,发现推理一直失败。通过 Google Cloud Platform 网络控制台的 API 仪表板监控 API 使用情况,发现前几个请求以 HTTP 响应代码 200 成功,但其余请求以响应代码 429(“请求过多”)失败。可以请求增加配额,但在此之前,应优化管道以减少对服务的请求。

2.4 批量请求

不需要逐个向服务发送航班信息,可以批量请求。有两种批量请求的方式:基于计数(累积航班直到达到阈值,如 100 个航班,然后使用这些航班调用服务)或基于时间(在固定时间跨度内累积航班,如两分钟,然后使用该时间段内累积的所有航班信息调用服务)。

在 Cloud Dataflow 管道中基于计数批量请求的代码如下:

.apply(Window.into(new GlobalWindows()).triggering(
    Repeatedly.forever(AfterFirst.of(
        AfterPane.elementCountAtLeast(100),
        AfterProcessingTime.pastFirstElementInPane().plusDelayOf(Duration.standardMinutes(1))))

但由于管道使用了滑动窗口来计算平均到达延迟,应用 GlobalWindow 会导致对象重新窗口化,产生不必要的副作用。因此,需要在滑动窗口的上下文中进行批量处理。

解决方案是为每个航班发出一个键值对,键反映航班的批次号:

Flight f = c.element();
String key = "batch=" + System.identityHashCode(f) % NUM_BATCHES;
c.output(KV.of(key, f));

然后使用 GroupByKey 进行分组:

PCollection<FlightPred> lines = outFlights
   .apply("Batch->Flight", ParDo.of(...))
   .apply("CreateBatches", GroupByKey.<String, Flight> create())

GroupByKey 输出的 Iterable<Flight> 可以用于批量预测:

double[] batchPredict(Iterable<Flight> flights, float defaultValue) throws IOException, GeneralSecurityException {
    Request request = new Request();
    for (Flight f : flights) {
        request.instances.add(new Instance(f));
    }
    Response resp = sendRequest(request);
    double[] result = resp.getOntimeProbability(defaultValue);
    return result;
}

通过这种更改,请求数量大幅下降,并且在配额范围内。

以下是批量请求优化的流程图:

graph TD;
    A[开始] --> B[获取航班数据];
    B --> C[发出键值对];
    C --> D[GroupByKey分组];
    D --> E[批量预测];
    E --> F[结束];

3. 流式处理管道

3.1 合并 PCollections

当处理流式数据时,数据从 Cloud Pub/Sub 读取并流入 BigQuery。之前从文本文件读取数据时, arrived wheelsoff 事件来自同一个文件,但模拟代码将这些事件流式传输到 Cloud Pub/Sub 的两个不同主题。为了重用数据处理代码,需要合并这两个 PCollection 。可以使用 Flatten 转换来实现:

// 空集合开始
PCollectionList<Flight> pcs = PCollectionList.empty(p);
// 从两个主题读取航班
for (String eventType : new String[]{"wheelsoff", "arrived"}){
    String topic = "projects/" + options.getProject() + "/topics/" + eventType;
    PCollection<Flight> flights = p.apply(eventType + ":read",
        PubsubIO.<String> read().topic(topic)
               .withCoder(StringUtf8Coder.of())
               .timestampLabel("EventTimeStamp"))
          .apply(eventType + ":parse",
               ParDo.of(new DoFn<String, Flight>() {
                    @ProcessElement
                    public void processElement(ProcessContext c) throws Exception {
                        String line = c.element();
                        Flight f = Flight.fromCsv(line);
                        if (f != null) {
                            c.output(f);
                        }
                    }
                }));
    pcs = pcs.and(flights);
}
// 合并集合
return pcs.apply(Flatten.<Flight>pCollections());

这里需要注意 PubsubIO 读取时的 timestampLabel 。如果不指定,消息将被分配插入到 Pub/Sub 时的时间戳。为了使用航班的实际时间作为时间戳,需要在发布消息时指定 EventTimeStamp 属性:

def publish(topics, allevents, notify_time):
    timestamp = notify_time.strftime(RFC3339_TIME_FORMAT)
    for key in topics:  # 'departed', 'arrived', etc.
        topic = topics[key]
        events = allevents[key]
        with topic.batch() as batch:
            for event_data in events:
                batch.publish(event_data, EventTimeStamp=timestamp)

3.2 流式输出到 BigQuery

FlightPred 信息流式传输到 BigQuery 需要创建 TableRow 并写入 BigQuery 表:

String outputTable = options.getProject() + ':' + BQ_TABLE_NAME;
TableSchema schema = new TableSchema().setFields(getTableFields());
PCollection<FlightPred> preds = addPredictionInBatches(outFlights);
PCollection<TableRow> rows = toTableRows(preds);
rows.apply("flights:write_toBQ", BigQueryIO.Write.to(outputTable)
    .withSchema(schema));

3.3 执行流式处理管道

编写好管道代码后,可以启动第 4 章的模拟程序将记录流式传输到 Cloud Pub/Sub:

cd ../04_streaming/simulate
python simulate.py --startTime "2015-04-01 00:00:00 UTC" --endTime "2015-04-03 00:00:00 UTC" --speedFactor 60

然后启动 Cloud Dataflow 管道来消费这些记录,调用航班机器学习服务,并将结果记录写入 BigQuery:

mvn compile exec:java \
-Dexec.mainClass=com.google.cloud.training.flights.AddRealtimePrediction \
-Dexec.args="--realtime --speedupFactor=60 --maxNumWorkers=10 --autoscalingAlgorithm=THROUGHPUT_BASED"

通过 realtime 选项可以在批处理和实时处理的 InputOutput 实现之间切换。

3.4 处理延迟和乱序记录

在实际应用中,航班记录可能会延迟或乱序到达。可以通过在模拟程序使用的 BigQuery SQL 语句中添加随机延迟来模拟这种情况:

SELECT
    EVENT,
    NOTIFY_TIME AS ORIGINAL_NOTIFY_TIME,
    TIMESTAMP_ADD(NOTIFY_TIME, INTERVAL CAST (0.5 + RAND()*120 AS INT64) SECOND) AS NOTIFY_TIME,
    EVENT_DATA
FROM
    `flights.simevents`

这里使用 RAND() 函数生成 0 到 1 之间的随机数,乘以 120 得到 0 到 2 分钟的延迟。

还可以模拟其他延迟分布:
- 均匀分布延迟 :如果要模拟 90 到 120 秒的延迟,可以使用 CAST(90.5 + RAND()*30 AS INT64)
- 指数分布延迟 :使用 CAST(-LN(RAND()*0.99 + 0.01)*30 + 90.5 AS INT64) 来模拟指数分布延迟。
- 正态分布延迟 :在 BigQuery 中使用用户定义函数(UDF)来生成正态分布的随机变量:

js = """
    var u = 1 - Math.random();
    var v = 1 - Math.random();
    var f = Math.sqrt(-2 * Math.log(u)) * Math.cos(2*Math.PI*v);
    f = f * sigma + mu;
    if (f < 0)
        return 0;
    else
        return f;
""".replace('\n', ' ')

然后在 SQL 中创建临时 UDF 并使用:

sql = """
CREATE TEMPORARY FUNCTION
trunc_rand_normal(x FLOAT64, mu FLOAT64, sigma FLOAT64)
RETURNS FLOAT64
LANGUAGE js AS "{}";
SELECT
    trunc_rand_normal(ARR_DELAY, 90, 15) AS JITTER
FROM
    ...
""".format(js).replace('\n', ' ')

为了实验不同类型的延迟,可以修改模拟代码:

jitter = 'CAST (-LN(RAND()*0.99 + 0.01)*30 + 90.5 AS INT64)'
# 运行查询以提取模拟事件
querystr = """\
SELECT
    EVENT,
    TIMESTAMP_ADD(NOTIFY_TIME, INTERVAL {} SECOND) AS NOTIFY_TIME,
    EVENT_DATA
FROM
    `cloud-training-demos.flights.simevents`
WHERE
    NOTIFY_TIME >= TIMESTAMP('{}')
    AND NOTIFY_TIME < TIMESTAMP('{}')
ORDER BY
    NOTIFY_TIME
""".format(jitter, start_time, end_time)

以下是流式处理管道的流程图:

graph TD;
    A[开始] --> B[从 Cloud Pub/Sub 读取数据];
    B --> C[合并 PCollections];
    C --> D[添加预测信息];
    D --> E[转换为 TableRow];
    E --> F[写入 BigQuery];
    F --> G[结束];

综上所述,通过上述步骤,我们可以从批处理过渡到流式处理,实现航班实时机器学习预测,并处理可能出现的延迟和乱序记录。在实际应用中,可以根据具体需求调整代码和参数,以达到最佳性能和准确性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值