location:beijing
涉及知识:大模型压缩、知识蒸馏

1. 核心内容
本文提出在一个贝叶斯估计框架内估计闭源语言模型的输出分布,包括先验估计和后验估计。先验估计的目的是通过闭源模型生成的语料库(可能包含模型的粗粒度信息)得到先验分布;后验估计使用代理模型来更新先验分布并生成后验分布。利用这两个分布来进行知识蒸馏。
2. 方法
该文章的创新点是在知识蒸馏的过程中,使用一个代理模型作为教师模型和学生模型的中介,该项目配置如Table. 1
| 项目 | 方法 |
|---|---|
| benchmarks | BBH\ARC\AGIEval\MMLU\CSQA\GSM8K\ |
| teacher model | GPT-4 |
| proxy model | LLaMA-33B |
| student model | LLaMA-7B/13B |
一些参数表示如下表
| 变量 | 含义 |
|---|---|
| T \mathcal{T} T | 闭源的教师模型 |
| S \mathcal{S} S | 学生模型 |
| M \mathcal{M} M | 开源的代理模型 |
| X X X | 输入的token序列 |
| Y Y Y | 输出的token序列 |
| p Y t p_{Y_t} pYt | T \mathcal{T} T输出的概率Pr ( Y t ( Y_{t} (Yt | X , Y < t ) X, Y_{< t}) X,Y<t) |
| q Y t q_{Y_t} qYt | S \mathcal{S} S输出的概率Pr ( Y t (Y_{t} (Yt | X , Y < t ) X,Y_{<t}) X,Y<t) |
| P Y t P_{Y_t} PYt | 与 p Y t p_{Y_t} pYt相关的离散随机变量 |
用指示函数 I Y t = w \mathbb{I}_{Y_t=\boldsymbol{w}} IYt=w(其实不是空心的I应该是空心的1,没法在优快云打出来)表示 T \mathcal{T} T在 t t t时刻产生的one-hot编码标签。
传统的目标函数可以表示为
L t traditional = − ∑ w ∈ V I Y t = w log q Y t = w + ∑ w ∈ V p Y t = w log p Y t = w q Y t = w (1) \mathcal{L}_{t}^{\text{traditional}}=-\sum_{w\in\mathbb{V}}\mathbb{I}_{Y_{t}=w}\log q_{Y_{t}=w}+\sum_{w\in\mathbb{V}}p_{Y_{t}=w}\log\frac{p_{Y_{t}=w}}{q_{Y_{t}=w}}\tag{1} Lttraditional=−w∈V∑IYt=wlogqYt=w+w∈V∑pYt=wlogqYt=wpYt=w(1)式中 V \mathbb{V} V表示词典, w w w是词典中的一个token,可以看出, L t traditional \mathcal{L}_{t}^{\text{traditional}} Lttraditional由两部分组成,第一部分表示由硬标签(Fig.2)产出的交叉熵损失(交叉熵与相对熵在第三章详细说明),第二部分表示用软标签计算出的KL损失,一般情况下由于 p Y t p_{Y_{t}} pYt很难得到,第二项是被忽略的。

这篇论文就是解决第二项的问题。
2.1 先验估计
先验估计的目的是使用 T \mathcal{T} T生成的语料库 C \mathcal{C} C,得到每一步 t t t的近似 p Y t p_{Y_{t}} pYt的粗粒度估计 p ^ Y t \hat{p}_{Y_t} p^Yt,来自改良的n-gram算法(基于第n个项目的出现只与前面n-1个项目有关)来实现,对于给定一个输出token序列 Y ≤ t ∈ C Y_{\leq t}\in\mathcal{C} Y≤t∈C,假设 Y t = w t Y_{t}=w_t Yt=wt其中 w t w_t wt是 V \mathbb{V} V中的一个token,对于 V \mathbb{V} V中的某个token w w w如果有 w = w t w=w_t w=wt,有
p ^ Y t = w = # ( Y t = w , Y t − 1 = w t − 1 , … , Y t − n = w t − n ) γ # ( Y t − 1 = w t − 1 , … , Y t − n = w t − n ) + γ − 1 γ (2) \hat{p}_{Y_t=w}=\frac{\#(Y_t=w,Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}{\gamma\#(Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}+\frac{\gamma-1}{\gamma}\tag{2} p^Yt=w=γ#(Y

最低0.47元/天 解锁文章

被折叠的 条评论
为什么被折叠?



