ICML 2024
paper
code
解决离线到在线过程中,判别器对齐问题。
Intro
传统采用离线模仿学习结合基于GAIL的在线模仿学习,容易因为在线阶段初始化的判别器表现具有随机性,与离线获得的策略不一致。因此,本文提出的OLLIE,便是利用混合质量的数据,实现判别器与策略之间的对齐,从何防止O2O的performance drop。
Method
GAIL
GAIL是一种传统在线模仿学习算法,其目标是对抗的学习一个判别器,用于区分专家数据与在线数据
min π max D E ρ π [ log D ( s , a ) ] + E ρ ˉ e [ log ( 1 − D ( s , a ) ) ] . ( 2 ) \min_\pi\max_D\mathbb{E}_{\rho^\pi}[\log D(s,a)]+\mathbb{E}_{\bar{\rho}^e}[\log(1-D(s,a))].(2) πminDmaxEρπ[logD(s,a)]+Eρˉe[log(1−D(s,a))].(2).
最优判别器输出表示为 D ∗ ( s , a ) = ρ π ( s , a ) ρ π ( s , a ) + ρ ~ e ( s , a ) . D^*(s,a)=\frac{\rho^\pi(s,a)}{\rho^\pi(s,a)+\tilde{\rho}^e(s,a)}. D∗(s,a)=ρπ(s,a)+ρ~e(s,a)ρπ(s,a).
Offline IL
假设专家数据 D e D_e De以及混合数据 D o ≐ D e ∪ D s \mathcal{D}_{o}\doteq\mathcal{D}_{e}\cup\mathcal{D}_{s} Do≐De∪Ds的状态动作分布分别为 ρ ~ e > 0 \tilde{\rho}^{e}>0 ρ~e>0以及 ρ ~ o > 0 \tilde{\rho}^o>0 ρ~o>0。离线学习的目标可以看作状态动作分布匹配问题,通过一个逆KL散度表达
min π D K L ( ρ π ∥ ρ ~ e ) = E ( s , a ) ∼ ρ π [ log ρ π ( s , a ) ρ ~ e ( s , a ) ] \operatorname*{min}_{\pi}D_{\mathrm{KL}}(\rho^{\pi}\|\tilde{\rho}^{e})=\mathbb{E}_{(s,a)\sim\rho^{\pi}}\left[\log{\frac{\rho^{\pi}(s,a)}{\tilde{\rho}^{e}(s,a)}}\right] πminDKL(ρπ∥ρ~e)=E(s,a)∼ρπ[logρ~e(s,a)ρπ(s,a)]
为了让混合数据参与上式的处理过程,在对数项的分子分母同时添加 ρ ~ o \tilde{\rho}^o ρ~o,原问题改造为
max π E ( s , a ) ∼ ρ π [ R ~ ( s , a ) ] − D K L ( ρ π ∥ ρ ~ o ) ( 5 ) \max_\pi\mathbb{E}_{(s,a)\sim\rho^\pi}\big[\tilde{R}(s,a)\big]-D_{\mathrm{KL}}(\rho^\pi\|\tilde{\rho}^o)~~~~(5) πmaxE(s,a)∼ρπ[R~(s,a)]−DKL(ρπ∥ρ~o) (5)
其中 R ~ ( s , a ) ≐ log ρ ˉ e ( s , a ) ρ ˉ o ( s , a ) \tilde{R}(s,a)\doteq\log\frac{\bar{\rho}^e(s,a)}{\bar{\rho}^o(s,a)} R~(s,a)≐logρˉo(s,a)ρˉe(s,a)。对于低维表格环境可以通过统计计算状态动作分布,而对于高维环境则可以通过训练一个判别器进行区分
max d E ρ ~ e [ log d ( s , a ) ] + E ρ ~ o [ log ( 1 − d ( s , a ) ) ] ( 6 ) \max_d\mathbb{E}_{\tilde{\rho}^e}\big[\log d(s,a)\big]+\mathbb{E}_{\tilde{\rho}^o}\big[\log(1-d(s,a))\big]\quad(6) dmaxEρ~e[logd(s,a)]+Eρ~o[log(1−d(s,a))](6)
进而得到 R ~ ( s , a ) = log ρ ~ e ( s , a ) ρ ~ o ( s , a ) = log d ∗ ( s , a ) 1 − d ∗ ( s , a ) . ( 7 ) \tilde{R}(s,a)=\log\frac{\tilde{\rho}^{e}(s,a)}{\tilde{\rho}^{o}(s,a)}=\log\frac{d^{*}(s,a)}{1-d^{*}(s,a)}.\quad(7) R~(s,a)=logρ~o(s,a)ρ~e(s,a)=log1−d∗(s,a)d

最低0.47元/天 解锁文章

398

被折叠的 条评论
为什么被折叠?



