NIPS 2022
paper
code
可看作是一种数据增强
Intro
最近的一些工作通过将离线 RL 视为一种通用的序列生成问题,并采用诸如 Transformer 架构的序列模型来模拟轨迹上的分布。然而,一般离线 RL 任务中使用的训练数据集通常非常有限,并经常因分布覆盖不足而受到影响,这可能对训练序列生成模型不利。
为此,作者提出了 Bootstrapped Transformer 算法,该算法结合了自举(bootstrapping)的思想,利用学习到的模型生成更多的离线数据,以进一步增强序列模型的训练。通过在两个离线 RL 基准上的广泛实验,作者证明了他们的模型可以大幅弥补现有的离线 RL 训练限制,并超越其他强基线方法。
Method
首先根据序列数据添加累计回报
R
t
=
∑
t
′
=
t
T
γ
t
′
−
t
r
t
′
R_{t}=\sum_{t^{\prime}=t}^{T}\gamma^{t^{\prime}-t}r_{t^{\prime}}
Rt=∑t′=tTγt′−trt′, 并对原始数据处理成离散的token space
τ
=
τ
dis
=
(
…
,
s
t
1
,
s
t
2
,
…
,
s
t
N
,
a
t
1
,
a
t
2
,
…
,
a
t
M
,
r
t
,
R
t
,
…
)
.
\tau =\tau_{\text{dis}}=\begin{pmatrix}\ldots,s_t^1,s_t^2,\ldots,s_t^N,a_t^1,a_t^2,\ldots,a_t^M,r_t,R_t,\ldots\end{pmatrix}.
τ=τdis=(…,st1,st2,…,stN,at1,at2,…,atM,rt,Rt,…).
接下来采用TT架构,通过最大化似然函数
L
\mathcal{L}
L优化序列模型
log
P
θ
(
τ
t
∣
τ
<
t
)
=
∑
i
=
1
N
log
P
θ
(
s
t
i
∣
s
t
<
i
,
τ
<
t
)
+
∑
j
=
1
M
log
P
θ
(
a
t
j
∣
a
t
<
j
,
s
t
,
τ
<
t
)
+
log
P
θ
(
r
t
∣
a
t
,
s
t
,
τ
<
t
)
+
log
P
θ
(
R
t
∣
r
t
,
a
t
,
s
t
,
τ
<
t
)
L
(
τ
)
=
∑
t
=
1
T
log
P
θ
(
τ
t
∣
τ
<
t
)
,
\begin{aligned} \log P_{\theta}(\tau_{t}|\tau_{<t})& =\sum_{i=1}^{N}\log P_{\theta}(s_{t}^{i}|s_{t}^{<i},\tau_{<t})+\sum_{j=1}^{M}\log P_{\theta}(a_{t}^{j}|a_{t}^{<j},s_{t},\tau_{<t}) \\ &+\log P_\theta(r_t|\boldsymbol{a}_t,s_t,\tau_{<t})+\log P_\theta(R_t|r_t,\boldsymbol{a}_t,s_t,\tau_{<t})\\ \mathcal{L}(\tau)&=\sum_{t=1}^T\log P_\theta(\tau_t|\tau_{<t}), \end{aligned}
logPθ(τt∣τ<t)L(τ)=i=1∑NlogPθ(sti∣st<i,τ<t)+j=1∑MlogPθ(atj∣at<j,st,τ<t)+logPθ(rt∣at,st,τ<t)+logPθ(Rt∣rt,at,st,τ<t)=t=1∑TlogPθ(τt∣τ<t),
Trajectory Generation
接下俩便是利用模型生成轨迹实现数据增强。文章这里提出了两种增强方法
-
Autoregressive generation: y ~ n ∼ P θ ( y n ∣ y ~ < n , τ ≤ T − T ′ ) \tilde{y}_n\sim P_\theta\left(y_n|\tilde{y}_{<n},\tau_{\leq T-T'}\right) y~n∼Pθ(yn∣y~<n,τ≤T−T′)。通过自回归的方法生成各个token。

-
Teacher-forcing generation: y ~ n ∼ P θ ( y n ∣ y < n , τ ≤ T − T ′ ) \tilde{y}_n\sim P_\theta\left(y_n|{y}_{<n},\tau_{\leq T-T'}\right) y~n∼Pθ(yn∣y<n,τ≤T−T′),正常序列生成

增强后的tokens数据将于原始数据concatenate共同训练序列模型
为了防止由于训练数据不准确而导致的累积学习偏差,根据生成百分比 η% 选择每批中置信度分数最高的部分轨迹。置信度定义为所有生成的令牌的平均对数概率为:
c
(
τ
)
=
1
T
′
(
N
+
M
+
2
)
∑
t
=
T
−
T
′
+
1
T
log
P
θ
(
τ
t
∣
τ
<
t
)
c(\tau)=\frac{1}{T'(N+M+2)}\sum_{t=T-T'+1}^T\log P_\theta(\tau_t|\tau_{<t})
c(τ)=T′(N+M+2)1t=T−T′+1∑TlogPθ(τt∣τ<t)
伪代码

算法提出两种Bootstrapp形式。这两种方法都将首先在原始的离线轨迹上训练序列模型,然后利用它生成新的轨迹。但是对增强数据的利用方式不同
- Boot-o: 将再次在生成的轨迹上训练模型,并在使用它们后立即丢弃它们
- Boot-r:将生成的轨迹附加到原始数据集中,生成的轨迹将在附加到数据集之后的每个时期使用。
实验结果看第一种结合自回归的增强稍微好点
results
对比其他Offline的方法

对比两种增强方法


3983

被折叠的 条评论
为什么被折叠?



