前言
在大规模自然语言处理模型(如 Word2Vec)中,高效计算词的概率分布是训练的关键瓶颈。标准 Softmax 在大词汇表下计算代价高昂,分层 Softmax(Hierarchical Softmax) 通过引入二叉树结构,将时间复杂度从 O(∣V∣)O(|\mathcal{V}|)O(∣V∣) 降至 O(log∣V∣)O(\log |\mathcal{V}|)O(log∣V∣),成为经典优化方案。
🌲 一、背景:为什么要用分层 Softmax?
在词向量建模任务中(如 Skip-gram),我们需要计算上下文词 wow_owo 在给定中心词 wcw_cwc 下的条件概率:
P(wo∣wc) P(w_o \mid w_c) P(wo∣wc)
传统方法使用标准 Softmax 函数:
P(wo∣wc)=exp(uwo⊤vwc)∑w′∈Vexp(uw′⊤vwc) P(w_o \mid w_c) = \frac{\exp(\mathbf{u}_{w_o}^\top \mathbf{v}_{w_c})}{\sum_{w' \in \mathcal{V}} \exp(\mathbf{u}_{w'}^\top \mathbf{v}_{w_c})} P(wo∣wc)=∑w′∈Vexp(uw′⊤vwc)exp(uwo⊤vwc)
其中:
- vwc∈Rd\mathbf{v}_{w_c} \in \mathbb{R}^dvwc∈Rd:中心词 wcw_cwc 的词向量(输入向量);
- uwo∈Rd\mathbf{u}_{w_o} \in \mathbb{R}^duwo∈Rd:上下文词 wow_owo 的输出向量(输出嵌入);
- V\mathcal{V}V:整个词汇表,通常包含数万至数十万个词。
❗ 问题:计算效率低
标准 Softmax 的分母(称为“配分函数”)需要对整个词汇表求和。这意味着:
- 每次前向传播需计算 ∣V∣|\mathcal{V}|∣V∣ 个点积;
- 每次反向传播需更新 ∣V∣|\mathcal{V}|∣V∣ 个梯度;
- 当 ∣V∣=105|\mathcal{V}| = 10^5∣V∣=105 时,单次计算成本极高,无法满足大规模训练需求。
✅ 解决方案:分层 Softmax(Hierarchical Softmax)
核心思想:
将词汇表组织成一棵二叉树,每个词对应一个叶子节点。从根节点到叶子节点的路径唯一确定该词。概率计算转化为路径上一系列二元决策的概率乘积。
优势:
- 时间复杂度从 O(∣V∣)O(|\mathcal{V}|)O(∣V∣) 降至 O(log∣V∣)O(\log |\mathcal{V}|)O(log∣V∣);
- 无需负采样,保持概率归一性;
- 可与 Huffman 编码结合,高频词路径更短,进一步加速。
🌳 二、二叉树结构说明(对应图14.2.1)
在分层 Softmax 中,词汇表被编码为一棵满二叉树(通常为 Huffman 树,以最小化期望路径长度):
- 叶子节点:每个叶节点唯一对应词汇表中的一个词 wiw_iwi。
- 内部节点:不表示任何实际词,仅作为路径上的“决策点”,每个内部节点拥有一个可学习的向量 un\mathbf{u}_nun。
- 路径唯一性:从根节点到任意叶子节点的路径是唯一的,路径长度 L(w)L(w)L(w) 表示该词的编码长度。
📌 示例:词 w3w_3w3 的路径
假设从根节点到 w3w_3w3 的路径为:
根 → n(w₃,1) → n(w₃,2) → n(w₃,3) → w₃
则:
- 路径节点数 L(w3)=5L(w_3) = 5L(w3)=5(包括根和叶子);
- 需要做 L(w3)−1=4L(w_3) - 1 = 4L(w3)−1=4 次左右子节点选择决策。
🔤 三、符号定义
| 符号 | 含义 |
|---|---|
| wow_owo | 上下文词(输出词) |
| wcw_cwc | 中心词(输入词) |
| vwc\mathbf{v}_{w_c}vwc | 中心词向量(输入嵌入) |
| un\mathbf{u}_{n}un | 内部节点 nnn 的输出向量(可训练参数) |
| L(w)L(w)L(w) | 从根到词 www 的路径上的节点总数(含首尾) |
| n(w,j)n(w,j)n(w,j) | 从根到 www 的路径上第 jjj 个节点(j=1j=1j=1 为根) |
| σ(x)\sigma(x)σ(x) | Sigmoid 函数:σ(x)=11+e−x\sigma(x) = \frac{1}{1 + e^{-x}}σ(x)=1+e−x1 |
| leftChild(n)\text{leftChild}(n)leftChild(n) | 节点 nnn 的左子节点 |
| ⟦⋅⟧\llbracket \cdot \rrbracket[[⋅]] | Iverson bracket:条件为真时值为 1,否则为 0 |
✅ 四、公式详解
分层 Softmax 的核心公式如下:
P(wo∣wc)=∏j=1L(wo)−1σ(⟦n(wo,j+1)=leftChild(n(wo,j))⟧⋅un(wo,j)⊤vwc) P(w_o \mid w_c) = \prod_{j=1}^{L(w_o)-1} \sigma\left( \llbracket n(w_o, j+1) = \text{leftChild}(n(w_o, j)) \rrbracket \cdot \mathbf{u}_{n(w_o,j)}^\top \mathbf{v}_{w_c} \right) P(wo∣wc)=j=1∏L(wo)−1σ([[n(wo,j+1)=leftChild(n(wo,j))]]⋅un(wo,j)⊤vwc)
🧩 分步拆解
1️⃣ 乘积形式:路径概率分解
- 概率不是一次性计算,而是路径上每一步选择正确方向的概率的连乘积。
- 总步数 = L(wo)−1L(w_o) - 1L(wo)−1,因为最后一个节点是叶子,无需选择。
2️⃣ 方向判断:Iverson Bracket 控制符号
- 对于路径上第 jjj 个节点 n(wo,j)n(w_o, j)n(wo,j),判断下一个节点 n(wo,j+1)n(w_o, j+1)n(wo,j+1) 是否为其左子节点。
- 若是左子节点 ⇒ Iverson bracket = 1 ⇒ 输入为正:σ(+u⊤v)\sigma(+\mathbf{u}^\top \mathbf{v})σ(+u⊤v)
- 若是右子节点 ⇒ Iverson bracket = 0 ⇒ 输入为负:σ(−u⊤v)=1−σ(u⊤v)\sigma(-\mathbf{u}^\top \mathbf{v}) = 1 - \sigma(\mathbf{u}^\top \mathbf{v})σ(−u⊤v)=1−σ(u⊤v)
💡 关键技巧:
使用 σ(±u⊤v)\sigma(\pm \mathbf{u}^\top \mathbf{v})σ(±u⊤v) 统一表达左右选择概率:
- 左子节点概率 = σ(u⊤v)\sigma(\mathbf{u}^\top \mathbf{v})σ(u⊤v)
- 右子节点概率 = σ(−u⊤v)=1−σ(u⊤v)\sigma(-\mathbf{u}^\top \mathbf{v}) = 1 - \sigma(\mathbf{u}^\top \mathbf{v})σ(−u⊤v)=1−σ(u⊤v)
3️⃣ Sigmoid 建模:局部二元分类器
- 在每个内部节点 n(wo,j)n(w_o, j)n(wo,j),我们训练一个“二元分类器”,决定向左还是向右走。
- 输入特征:中心词向量 vwc\mathbf{v}_{w_c}vwc
- 权重参数:节点向量 un(wo,j)\mathbf{u}_{n(w_o,j)}un(wo,j)
- 输出:选择左子节点的概率 σ(un⊤vwc)\sigma(\mathbf{u}_{n}^\top \mathbf{v}_{w_c})σ(un⊤vwc)
4️⃣ 最终概率:路径连乘
- 只有沿着通向 wow_owo 的唯一路径每一步都“选对方向”,才能最终到达该词。
- 因此,P(wo∣wc)P(w_o \mid w_c)P(wo∣wc) 是路径上所有“正确选择”概率的乘积。
⚙️ 五、训练与实现
🎯 目标函数
最大化对数似然:
L=logP(wo∣wc)=∑j=1L(wo)−1logσ([direction]⋅unj⊤vwc) \mathcal{L} = \log P(w_o \mid w_c) = \sum_{j=1}^{L(w_o)-1} \log \sigma\left( [\text{direction}] \cdot \mathbf{u}_{n_j}^\top \mathbf{v}_{w_c} \right) L=logP(wo∣wc)=j=1∑L(wo)−1logσ([direction]⋅unj⊤vwc)
其中direction由 Iverson bracket 决定符号。
🔄 梯度更新
对每个路径上的节点向量 unj\mathbf{u}_{n_j}unj 和中心词向量 vwc\mathbf{v}_{w_c}vwc,可通过链式法则求导并使用 SGD 更新。
🆚 六、与负采样的对比
| 特性 | 分层 Softmax | 负采样 (Negative Sampling) |
|---|---|---|
| 概率归一性 | ✅ 严格归一 | ❌ 近似,不归一 |
| 计算复杂度 | O(logV)O(\log{\mathcal{V}})O(logV) | O(K)O(K)O(K),K 为负样本数 |
| 实现复杂度 | 需构建与维护树结构 | 简单,只需采样 |
| 适合场景 | 词汇表极大,需精确概率 | 通用,尤其适合中等词汇表 |
| 高频词优化 | 可用 Huffman 树缩短路径 | 依赖采样分布 |
✅ 七、总结
分层 Softmax巧妙地将一个大规模多分类问题转化为一系列二分类问题,通过树形路径上的概率连乘实现高效计算。其优势在于:
- 高效性:计算复杂度从线性降至对数;
- 可学习性:节点向量可通过梯度下降优化;
- 灵活性:可结合 Huffman 编码优化高频词路径;
- 理论优雅:保持概率分布的归一性。
2839

被折叠的 条评论
为什么被折叠?



