一、简介
Wide & Deep Learning (以下简称 WDL)是解决点击率预估(CTR Prediction)问题比较重要的模型。WDL 在训练时,也面临着点击率预估领域存在的两个挑战:巨大的词表(Embedding Table),以及大量的数据吞吐。
业界比较有影响力的包含了 WDL 解决方案及评测结果的项目有 HugeCTR,该框架通过模型并行、三级流水线等技巧,解决了以上问题。在2020年 MLPerf 评测中,英伟达用 HugeCTR 实现了当时最快的 WDL 模型。
英伟达 Blog 给出的数据显示,在特定的、对齐后的硬件条件下,HugeCTR 的速度是 TensorFlow-CPU 的114倍,是 TensorFlow-GPU 的7.4倍,下图的纵坐标代表每轮迭代的延迟,数值越小,意味着性能越好:
而 OneFlow-WDL 比 HugeCTR 更快,每轮迭代延迟比 HugeCTR 更少:
上图节取自 DLPerf 中的 WDL(OneFlow-WDL vs. HugeCTR)评测报告,展示了 vocabulary size 倍增实验的结果,横坐标为实验时所取的 vocabulary size 参数大小,不断翻倍,直到超出框架的最大负荷(OOM 为 Out of Memory 的缩写)。
可以看到,相同条件下的各组实验中,OneFlow 的训练速度比 HugeCTR 快。并且,随着 vocabulary size 的增大,OneFlow 的每次迭代的 latency 几乎无变化,性能无损失。
使用 OneFlow 在 32 张 V100-SXM2-16GB 组成的集群上训练 WDL,可以支持8亿(819200000)大小的词表。
原始的日志数据、更详细的图表说明,可以参考最近公布的 DLPerf 关于 OneFlow 与 HugeCTR 实现相同结构 WDL 的性能测试报告。
OneFlow 作为一款通用的深度学习框架,所实现的 OneFlow-WDL 模型性能却超越了专为 CTR Prediction 问题设计的 HugeCTR 框架,内部的技术原理有哪些呢?本文将详细揭秘 OneFlow-WDL 实现的技术细节,并将不同的场景中 OneFlow-WDL 的多种分布式实现方案做详细的横向对比与分析,以方便读者根据自身需求与应用场景,选择最适合的方案。
熟悉 WDL 模型的读者,可以直接跳到第三节“如何在 OneFlow 中实现分布式 WDL” 开始阅读。
二、WDL、大词表与 OneFlow
OneFlow-WDL 为什么这么快,OneFlow 作为分布式最易用的深度学习框架,在实现 OneFlow-WDL 的过程中有哪些过人之处?这个问题难以简单直接地回答,为此我们准备了此节内容,将依次介绍:
- CTR Prediction 是什么?WDL 是什么?它们解决了什么问题?
- WDL 模型在实际工程中为什么需要分布式?为什么困难?
- OneFlow 为什么适合解决 WDL 这样的大模型问题?
读者可以根据自身情况,略过已经了解的部分,挑选自己还不了解的部分阅读即可。
2.1 CTR Prediction 与 WDL
2.1.1 CTR Prediction 问题
CTR Prediction 问题的目标是预测用户点击率,点击率作为一种指标,可以一定程度地反映用户对所提供的内容的感兴趣程度。因此,CTR Prediction 技术广泛应用在推荐、排序搜索、在线广告等领域。互联网公司大部分的服务都或多或少与 CTR Prediction 有关系,无论是 BAT 各家的广告业务,还是美团的首页 rank、头条的 feed 流,背后都有 CTR Prediction 的身影。
下面用一个简化的广告推荐例子,来说明 CTR Prediction 所要解决的问题及其解决思路。
以上的表格中,Item 是将要作为广告展示给用户的内容,我们希望能够根据用户的信息预测用户是否会点击这个广告,从而达到更精确推荐的目的。
我们将用户的相关信息抽象为特征(X),点击行为作为函数的目标(y),那么问题的核心在于,如何尽可能准确地从数据中学习到 X 与 y 的关系 。这可以简化地理解为一个分类问题。其中:y 为是否点击,X 为用户的相关信息。
用来解决 CTR Prediction 问题的模型有很多种,WDL 只是其中(很重要)的一种。
那么,WDL 为什么在 CTR Prediction 模型中如此重要呢?这主要是因为 WDL 模型的特殊历史地位决定的,可以从下图 (图来源:https://zhuanlan.zhihu.com/p/243243145)CTR Prediction 模型演化的历程中看到,WDL 起到了承上启下的作用,可以说,当今落地的 CTR Prediction 模型,都或多或少有 WDL 的影子,它们要么是通过改进 WDL 的 Wide 部分得到,要么是改进 WDL 的 Deep 部分得到,要么是 WDL 的前身。
2.1.2 WDL(Wide & Deep Learning)
上图展示了 Google 团队 WDL 论文中提出的模型的结构,WDL 模型分为 Wide 和 Deep 两部分:
- 单看 Wide 部分,与 Logistic Regression 模型并没有什么区别,就是一个线性模型。
- Deep 部分则是先对类型特征(Categorical Features)做 Embedding,在 Embedding 后接一个由多个隐藏层组成的神经网络,用于学习特征之间的高阶交叉组合关系。
由于 Embedding 机制的引入,WDL 相比于单纯的 Wide 模型有更强的泛化能力。Google 论文中展示了一个具体的例子:
这是一个关于 APP 推荐的例子,WDL 模型的具体网络结构以及输入如下:
- Wide 部分:线性模型部分通常输入稀疏的类别特征进行训练。另外,通过利用交叉特征高效的实现记忆能力,达到准确推荐的目的。比如在这个例子里,选取了两个类别特征(User Installed App 与 Impression App)做叉积变换的结果作为线性部分的输入。
- Deep部分:稀疏、高维的类别特征首先通过 Embeddings 转换为低维稠密向量,然后与连续值特征拼接在一起,作为MLP的输入。
- Wide & Deep联合训练:Wide 部分和 Deep 部分的输出通过加权方式合并到一起,并通过 Logistic Loss 得到最终输出。
注意其中的 Embedding 过程,通常情况下,因 Embedding 而引入的巨大词表(Embedding Table),是 WDL 必须使用分布式的最主要原因。
2.2 词表为什么这么大
本节将介绍:
- WDL 中为什么会有巨大的 Embedding Table?
- 为什么实现模型并行的分布式 WDL 是必需的也是困难的?
2.2.1 从 One-Hot 到 Embedding
WDL 需要 Embeding,虽然 Embedding 已经是 CTR 系统的基本操作,但是名气最大的可能还是词嵌入(word embedding),它也更容易解释。我们先简单介绍 One-Hot 与词嵌入,在下一节将看到,它与 WDL 中采用的 Embedding 没有本质区别。
众所周知,One-Hot 编码是最原始的用来表示字、词的方式。假如能使用的字只有五个:“牛、年、运、气、旺”,那么它们的 One-Hot 编码可以是:
牛: [1, 0, 0, 0, 0]
年: [0, 1, 0, 0, 0]
运: [0, 0, 1, 0, 0]
气: [0, 0, 0, 1, 0]
旺: [0, 0, 0, 0, 1]
“运气”这个词,采用以上 One-Hot 编码,就是:
这太稀疏了,在工程实现中有诸多弊端,于是我们可以准备一个矩阵,利用矩阵乘法:
将2个1×5的稀疏向量,“压缩到”到 2个1×3 的稠密向量中(词向量)。
以上包含有 w i j w_{ij} wij的矩阵,就称为词表(Embedding Table),且由于 One-Hot 编码的特殊性,One-Hot 向量与词表的矩阵乘法,其实相当于是一次“查表”的过程,如上例中,其实就是根据“1”在 One-Hot 向量中的位置(第2列、第3列),从词表中取出对应的向量(第2行、第3行)。
在实际工程中,并不会真正进行 One-Hot 编码(浪费内存),而是将 One-Hot 编码中 “1”的位置作为编号(通常称为 sparse ID),利用 gather 操作,从词表中取出词向量,等价于矩阵乘法。
如,上例的矩阵乘法,在 OneFlow 中用代码表示,则为:
embedding = flow.gather(embedding_table, index, axis=0) # index: [2, 3]
2.2.2 WDL 中的 Embedding
实际场景中,图像识别、语音识别等问题的输入常常具有连续、稠密且空间/时间有良好局部相关性的特点,CTR Prediction 问题则不同,大多数输入都是稀疏、离散、高维的类别特征(Categorical Features),因此通常需要通过 Embedding Table 将这些稀疏的类别特征转换为低维稠密向量。
我们已经知道,对于 Word Embedding,有多少个字(词),Embedding Table 就有多少行;那么,Wide & Deep 中的 Embedding Table 的行数,又是由什么决定的呢?它其实是所有 Categorical Features 做 One-Hot 编码后的维度总和。
我们以找到以下表格对应的词表的大小为例:
手机 | 样式 | 时间段 | 文章新吗 |
---|---|---|---|
小米 | 大图 | 早上 | 新 |
iPhone | 三图 | 下午 | 旧 |
三星 | 中图 | 早上 | 新 |
先对各个类别特征分别做 One-Hot 编码:
手机 | 样式 | 时间段 | 文章新吗 |
---|---|---|---|
(1, 0, 0) | (1, 0, 0) | (1, 0) | (1, 0) |
(0, 1, 0) | (0, 1, 0) | (0, 1) | (0, 1) |
(0, 0, 1) | (0, 0, 1) | (1, 0) | (1, 0) |
理论上,每个类别特征的 One-Hot 编码都应该有对应的 Embeding Table:
手机 | 样式 | 时间段 | 文章新吗 | |
---|---|---|---|---|
(1, 0, 0) | (1, 0, 0) | (1, 0) |