NeurIPS 2021
paper
Introduction
离线强化学习对于OOD数据容易出现外推误差,主流方法通过限制策略以及保守的Q价值估计予以解决,但前者受限于数据集质量的影响,后者则容易导致保守的值估计。
本文则是提高OOD数据的不确定性估计,通过Clip Q-learning进行价值函数层面的惩罚。本文首先提出SAC-N算法,发现简单的增加Q函数便可以提高算法sample- efficiency,但是过多的Q导致计算效率较低,且Q网络出现冗余,因此进一步提出了EDAC。
Method
2.1 SAC-N
文章首先回顾以往offline RL对OOD数据,为了防止Q估计过高所采取的办法,进而指出这些方法忽视了Clipped Q-learning的重要性。为了说明该观点,文章提出SAC-N,即将SAC中Q价值函数由2增加到N,
min
ϕ
i
E
s
,
a
,
s
′
∼
D
[
(
Q
ϕ
i
(
s
,
a
)
−
(
r
(
s
,
a
)
+
γ
E
a
′
∼
π
0
(
s
′
)
[
min
j
=
1
,
…
,
N
Q
ϕ
j
′
(
s
′
,
a
′
)
−
β
log
π
θ
(
a
′
∣
s
′
)
]
)
)
2
]
\min_{\phi_i}\mathbb{E}_{\mathbf{s},\mathbf{a},\mathbf{s'}\sim\mathcal{D}}\left[\left(Q_{\phi_i}(\mathbf{s},\mathbf{a})-\left(r(\mathbf{s},\mathbf{a})+\gamma\mathbb{E}_{\mathbf{a'}\sim\pi_0(\mathbf{s'})}\left[\min_{j=1,\ldots,N}Q_{\phi_j^{\prime}}\left(\mathbf{s'},\mathbf{a'}\right)-\beta\log\pi_\theta\left(\mathbf{a'}\mid\mathbf{s'}\right)\right]\right)\right)^2\right]
ϕiminEs,a,s′∼D[(Qϕi(s,a)−(r(s,a)+γEa′∼π0(s′)[j=1,…,NminQϕj′(s′,a′)−βlogπθ(a′∣s′)]))2]
max
θ
E
s
∼
D
,
a
∼
π
θ
(
⋅
∣
s
)
[
min
j
=
1
,
…
,
N
Q
ϕ
j
(
s
,
a
)
−
β
log
π
θ
(
a
∣
s
)
]
\max_{\theta}\mathbb{E}_{\mathbf{s}\sim\mathcal{D},\mathbf{a}\sim\pi_\theta(\cdot|\mathbf{s})}\left[\min_{j=1,\ldots,N}Q_{\phi_j}(\mathbf{s},\mathbf{a})-\beta\log\pi_\theta(\mathbf{a}\mid\mathbf{s})\right]
θmaxEs∼D,a∼πθ(⋅∣s)[j=1,…,NminQϕj(s,a)−βlogπθ(a∣s)]
实验结果也证明,简单的Q-ensemble + Clipped Q也能实现较好的效果
文章分析选择Q-ensemble中的最小值,变向惩罚那些高方差Q值的状态动作对,从而鼓励策略偏爱数据集中的动作。因为上式的bellman残差,明确的将样本数据的Q值进行对齐,使得数据集中的样本比OOD样本的方差更低,该差异性称为认知不确定性(epistemic uncertainty)。
clipped Q值与考虑Q值估计的置信界限密切相关。online RL方法通过在Q-ensemble均值上加标准差的方式来对Q值乐观地估计,即置信上界(upper-confidence bound,UCB),帮助鼓励探索具有高不确定性的未知动作。
而offline RL中数据集是固定的,只需要考虑离线数据,为了防止乐观估计导致Q过高,因此采用Q值估计的置信下界(lower-confidence bound,LCB),例如在Q-ensemble均值上减标准差,这样可以避免高风险的状态动作对。
clipped Q-learning算法选择最差情况,而非悲观估计,也可看作是LCB估计。假设Q满足为均值m(s,a),方差为
σ
(
s
,
a
)
\sigma(s,a)
σ(s,a)的高斯分布,则集成Q的最小期望为:
E
[
min
j
=
1
,
…
,
N
Q
j
(
s
,
a
)
]
≈
m
(
s
,
a
)
−
Φ
−
1
(
N
−
π
8
N
−
π
4
+
1
)
σ
(
s
,
a
)
\mathbb{E}\left[\min_{j=1,\ldots,N}Q_j(\mathrm{s},\mathrm{a})\right]\approx m(\mathrm{s},\mathrm{a})-\Phi^{-1}\left(\frac{N-\frac{\pi}{8}}{N-\frac{\pi}{4}+1}\right)\sigma(\mathrm{s},\mathrm{a})
E[j=1,…,NminQj(s,a)]≈m(s,a)−Φ−1(N−4π+1N−8π)σ(s,a)
其中 Φ 是标准高斯分布的 CDF。这种关系表明,使用Clipped Q 值类似于惩罚 Q 值的集合平均值,标准偏差由依赖于 N 的系数缩放。
进一步实验证明clipped的有效性
(a)表明对ID与OOD数据不同的惩罚力度
(b)OOD高方差且高惩罚力度,ID低方差且低惩罚力度。方差与惩罚力度正相关
(c)随着N的增加,OOD的惩罚力度增加,
2.2 EDAC
SAC-N虽好,但是高N值导致计算效率低,且存在结构性冗余。因此,EDAC希望减少N。
这部分文章从Q网络梯度的角度出发,当 Q 函数共享相似的局部结构时,学习策略的性能会显着下降,即网络梯度的余弦相似度与网络性能呈现负相关
文章指出网络梯度对齐现象严重(类比于结构相同,出现冗余),导致对相近分布的数据点惩罚不足,进而需要更多的Q网络。如下图所示
(a)中的三个Q对(s,a)梯度方差较小,导致在
w
2
w_2
w2方向上的OOD动作
a
+
k
w
2
a+kw_2
a+kw2惩罚力度小,
而(b)中,由于Q梯度的多样性,避免该情况的出现。
根据理论推导,对(s,a附近数据的Q方差进行一阶泰勒展开,那么沿着w方向的样本Q方差可表示为:
V
a
r
(
Q
ϕ
j
(
s
,
a
+
k
w
)
)
≈
V
a
r
(
Q
ϕ
j
(
s
,
a
)
+
k
⟨
w
,
∇
a
Q
ϕ
j
(
s
,
a
)
⟩
)
=
V
a
r
(
Q
(
s
,
a
)
+
k
<
w
,
∇
a
Q
ϕ
j
(
s
,
a
)
>
)
=
k
2
V
a
r
(
⟨
w
,
∇
a
Q
ϕ
j
(
s
,
a
)
⟩
)
=
k
2
w
⊺
V
a
r
(
∇
a
Q
ϕ
j
(
s
,
a
)
)
w
,
\begin{aligned} \mathrm{Var}\left(Q_{\phi_{j}}(\mathbf{s},\mathbf{a}+k\mathbf{w})\right)& \approx\mathrm{Var}\left(Q_{\phi_{j}}(\mathbf{s},\mathbf{a})+k\left\langle\mathbf{w},\nabla_{\mathbf{a}}Q_{\phi_{j}}(\mathbf{s},\mathbf{a})\right\rangle\right) \\ &=\mathrm{Var}\left(Q(\mathbf{s},\mathbf{a})+k\left<\mathbf{w},\nabla_\mathbf{a}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right>\right) \\ &=k^2\mathrm{Var}\left(\left\langle\mathbf{w},\nabla_\mathbf{a}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right\rangle\right) \\ &=k^2\mathbf{w}^\intercal\mathrm{Var}\left(\nabla_\mathbf{a}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right)\mathbf{w}, \end{aligned}
Var(Qϕj(s,a+kw))≈Var(Qϕj(s,a)+k⟨w,∇aQϕj(s,a)⟩)=Var(Q(s,a)+k⟨w,∇aQϕj(s,a)⟩)=k2Var(⟨w,∇aQϕj(s,a)⟩)=k2w⊺Var(∇aQϕj(s,a))w,
其中,w表示归一化的特征向量。
Lemma1:方差矩阵
V
a
r
(
∇
a
Q
ϕ
j
(
s
,
a
)
)
\mathrm{Var}\left(\nabla_{\mathbf{a}}Q_{\phi_{j}}(\mathbf{s},\mathbf{a})\right)
Var(∇aQϕj(s,a))的总方差等于
1
−
1
N
∑
j
=
1
N
∇
a
Q
ϕ
j
(
s
,
a
)
1-\frac1N\sum_{j=1}^N\nabla_\mathbf{a}Q_{\phi_j}(\mathbf{s},\mathbf{a})
1−N1∑j=1N∇aQϕj(s,a)。
同时,上述等式存在一个下界,设置
λ
m
i
n
\lambda_{\mathrm{min}}
λmin为方差矩阵最小特征值:
V
a
r
(
Q
ϕ
j
(
s
,
a
+
k
w
)
)
≈
k
2
w
⊺
V
a
r
(
∇
a
Q
ϕ
j
(
s
,
a
)
)
w
≥
k
2
w
m
i
n
⊺
V
a
r
(
∇
a
Q
ϕ
j
(
s
,
a
)
)
w
m
i
n
=
k
2
λ
m
i
n
.
\begin{aligned} \mathrm{Var}\left(Q_{\phi_{j}}(\mathbf{s},\mathbf{a}+k\mathbf{w})\right)& \approx k^{2}\mathbf{w}^{\intercal}\mathrm{Var}\left(\nabla_{\mathbf{a}}Q_{\phi_{j}}(\mathbf{s},\mathbf{a})\right)\mathbf{w} \\ &\geq k^2\mathbf{w}_{\mathrm{min}}^\intercal\mathrm{Var}\left(\nabla_{\mathbf{a}}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right)\mathbf{w}_{\mathrm{min}} \\ &=k^{2}\lambda_{\mathrm{min}}. \end{aligned}
Var(Qϕj(s,a+kw))≈k2w⊺Var(∇aQϕj(s,a))w≥k2wmin⊺Var(∇aQϕj(s,a))wmin=k2λmin.
增加Q网络的多样性,即增加Q方差,即最大化
λ
m
i
n
\lambda_{\mathrm{min}}
λmin, 由lemma1可知, 最大化
λ
m
i
n
\lambda_{\mathrm{min}}
λmin等价于最小化平均梯度:
minimize
ϕ
E
s
,
a
∼
D
[
⟨
1
N
∑
i
=
1
N
∇
a
Q
ϕ
i
(
s
,
a
)
,
1
N
∑
j
=
1
N
∇
a
Q
ϕ
j
(
s
,
a
)
⟩
]
\underset{\phi}{\operatorname*{minimize}}\left.\mathbb{E}_{\mathbf{s},\mathbf{a}\sim\mathcal{D}}\left[\left\langle\frac{1}{N}\sum_{i=1}^{N}\nabla_{\mathbf{a}}Q_{\phi_i}(\mathbf{s},\mathbf{a}),\frac{1}{N}\sum_{j=1}^{N}\nabla_{\mathbf{a}}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right\rangle\right]\right.
ϕminimizeEs,a∼D[⟨N1i=1∑N∇aQϕi(s,a),N1j=1∑N∇aQϕj(s,a)⟩]
进一步表达为实现Q梯度多样化的损失函数:
minimize
ϕ
J
E
S
(
Q
ϕ
)
:
=
E
s
,
a
∼
D
[
1
N
−
1
∑
1
≤
i
≠
j
≤
N
⟨
∇
a
Q
ϕ
i
(
s
,
a
)
,
∇
a
Q
ϕ
j
(
s
,
a
)
⟩
⏟
E
S
ϕ
i
,
ϕ
j
(
s
,
a
)
]
\underset{\phi}{\operatorname*{minimize}}J_{\mathrm{ES}}(Q_\phi):=\mathbb{E}_{\mathbf{s},\mathbf{a}\sim\mathcal{D}}\left[\frac{1}{N-1}\sum_{1\le i\ne j\le N}\underbrace{\left\langle\nabla_{\mathbf{a}}Q_{\phi_i}(\mathbf{s},\mathbf{a}),\nabla_{\mathbf{a}}Q_{\phi_j}(\mathbf{s},\mathbf{a})\right\rangle}_{\mathrm{ES}_{\phi_i,\phi_j}(\mathbf{s},\mathbf{a})}\right]
ϕminimizeJES(Qϕ):=Es,a∼D
N−111≤i=j≤N∑ESϕi,ϕj(s,a)
⟨∇aQϕi(s,a),∇aQϕj(s,a)⟩
可以看出上式采用梯度余弦相似度衡量Q网络的相似程度,算法伪代码如下,蓝色部分时针对SAC的修改,特别在第五步中最小化Q网络梯度余弦相似度