文章目录
前言
提示:这里可以添加本文要记录的大概内容:
本章介绍REINFORCE with Baseline的搭建以及A2C的搭建。
一、Baseline
1.公式回顾:
(1)状态价值函数:Vπ(s) = IEA~π[Qπ(s,A)] = Σaπ(a|s;θ)·Qπ(s,a)
它是关于动作价值函数的期望,A是随机变量,π(a|s;θ)是策略网络
(2)策略梯度:
∂
V
(
s
)
∂
θ
\frac{∂V(s)}{∂θ}
∂θ∂V(s) = IEA~π[
∂
l
n
π
(
A
∣
s
;
θ
)
∂
θ
\frac{∂lnπ(A|s;θ)}{∂θ}
∂θ∂lnπ(A∣s;θ) · Qπ(s,A)]
2.定义Baseline
1.设Baseline为b,b不依赖于动作A,则有如下结论:
IEA~π[b ·
∂
l
n
π
(
A
∣
s
;
θ
)
∂
θ
\frac{∂lnπ(A|s;θ)}{∂θ}
∂θ∂lnπ(A∣s;θ)] = 0
2.策略梯度可表示为:
∂ V ( s ) ∂ θ \frac{∂V(s)}{∂θ} ∂θ∂V(s)
= IEA~π[ ∂ l n π ( A ∣ s ; θ ) ∂ θ \frac{∂lnπ(A|s;θ)}{∂θ} ∂θ∂lnπ(A∣s;θ) · Qπ(s,A)] - IEA~π[b · ∂ l n π ( A ∣ s ; θ ) ∂ θ \frac{∂lnπ(A|s;θ)}{∂θ} ∂θ∂lnπ(A∣s;θ)]
= IEA~π[
∂
l
n
π
(
A
∣
s
;
θ
)
∂
θ
\frac{∂lnπ(A|s;θ)}{∂θ}
∂θ∂lnπ(A∣s;θ) · (Qπ(s,A) - b)]
b不影响期望,但是会影响蒙特卡洛对期望的近似,使蒙特卡洛方差降低,收敛更快。
3.常见的Baseline
1.b = Vπ(st)
合理性:
·st是先于at被观测到的,因此不依赖于At
·根据定义,Vπ是对Qπ的期望,很接近Qπ
二、Reinforce with Baseline
1.策略梯度的近似
1.g(at) =
∂
l
n
π
(
a
t
∣
s
t
;
θ
)
∂
θ
\frac{∂lnπ(a_t|s_t;θ)}{∂θ}
∂θ∂lnπ(at∣st;θ) · (Qπ(st,at) - Vπ(st))
无法计算Qπ(st,at)和Vπ(st),需要用蒙特卡洛近似
2.算法REINFORCE近似Qπ:
(1) 观测Agent的动作状态轨迹:st,at,rt,st+1,at+1,rt+1…
(2)ut = Σ ni=t(γi-t * ri),ut就是Qπ的无偏估计
3.神经网络v(s;w)近似Vπ
4.策略梯度的近似结果:
g(at) =
∂
l
n
π
(
a
t
∣
s
t
;
θ
)
∂
θ
\frac{∂lnπ(a_t|s_t;θ)}{∂θ}
∂θ∂lnπ(at∣st;θ) · (ut - v(s;w))
2.搭建训练需要的神经网络
2.1、策略网络π(a|s;θ)结构
(1)输入:状态s
(2)中间层:卷积层处理s得到一个特征向量,全连接层将向量映射到某个向量上
(3)输出:最后经过softmax函数输出动作的概率
2.2、价值网络v(s;w)结构
(1)输入:状态s
(2)中间层:卷积层和全连接层
(3)输出:Baseline
2.3、训练两个网络
令-δt = ut - v(s;w)
(1)训练策略网络: θ = θ + β · g(at) = θ - β · δt ·
∂
l
n
π
(
a
t
∣
s
t
;
θ
)
∂
θ
\frac{∂lnπ(a_t|s_t;θ)}{∂θ}
∂θ∂lnπ(at∣st;θ)
(2)训练价值网络:
损失:δt = v(s;w) - ut
梯度:
∂
δ
t
2
/
2
∂
w
\frac{∂δ^2_t/2}{∂w}
∂w∂δt2/2 = δt ·
∂
v
(
s
t
;
w
)
∂
w
\frac{∂v(s_t;w)}{∂w}
∂w∂v(st;w)
梯度下降:w = w - α · δt ·
∂
v
(
s
t
;
w
)
∂
w
\frac{∂v(s_t;w)}{∂w}
∂w∂v(st;w)
三、Advantage Actor-Critic(A2C)
将Baseline用在Actor-Critic上就得到了Advantage Actor-Critic
1.A2C的神经网络结构
(1)策略网络(actor):π(a|s;θ),用来近似策略函数π(a|s),控制agent作出动作。
(2)价值网络(critic):v(s;w),用来近似状态价值函数Vπ(s),评价状态的好坏。
2.训练神经网络
(1)观测到一条transition(st,at,rt,st+1)
(2)计算TD target:yt = rt + γ · v(st+1;w)
(3)计算TD error:δt = v(st;w) - yt
(4)更新策略网络参数θ:
θ = θ - β · δt
∂
l
n
π
(
a
t
∣
s
t
;
θ
)
∂
θ
\frac{∂lnπ(a_t|s_t;θ)}{∂θ}
∂θ∂lnπ(at∣st;θ)
(5)更新价值网络参数w:
w = w - α · δt ·
∂
v
(
s
t
;
w
)
∂
w
\frac{∂v(s_t;w)}{∂w}
∂w∂v(st;w)