前言:
在正文开始之前,首先给大家介绍一个不错的人工智能学习教程:https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程,感兴趣的读者可以自行查阅。
1. 引言
在强化学习中,策略梯度方法通过直接优化策略来最大化累积奖励。传统的策略梯度方法,如REINFORCE,存在高方差和收敛速度慢的问题。为了解决这些问题,Schulman等人提出了近端策略优化算法(Proximal Policy Optimization,PPO),它在更新策略时引入了信赖域约束,既保证了策略的更新幅度不过大,又简化了计算过程,被广泛应用于各种强化学习任务中。
2. 算法原理
PPO算法的核心思想是通过限制新旧策略之间的变化,防止策略更新过度。具体来说,PPO通过以下目标函数来更新策略:
L CLIP ( θ ) = E t [ min ( r t ( θ ) A ^ t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] L^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \ \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] LCLIP(θ)=Et[min(rt(θ)A^t, clip(rt(θ),1−ϵ,1+ϵ)A^t)]
其中:
- r t ( θ ) = π θ ( a t ∣ s t ) π θ old ( a t ∣ s t ) r_t(\theta) = \dfrac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt(θ)=πθold(at∣st)πθ(at∣st) 表示新旧策略的概率比。
- A ^ t \hat{A}_t A^t 是优势函数的估计。
- ϵ \epsilon ϵ 是控制策略更新幅度的超参数。
2.1 优势函数估计
优势函数 A ^ t \hat{A}_t A^t 可以通过广义优势估计(Generalized Advantage Estimation,GAE)来计算:
A ^ t = δ t + ( γ λ ) δ t + 1 + ( γ λ ) 2 δ t + 2 + … \hat{A}_t = \delta_t + (\gamma \lambda) \delta_{t+1} + (\gamma \lambda)^2 \delta_{t+2} + \dots A^t=δt+(γλ)δt+1+(γλ)2δt+2+…
其中,TD残差 δ t \delta_t δt 定义为:
δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)−V(st)
γ \gamma γ 是折扣因子, λ \lambda λ 是用于平衡偏差和方差的超参数。
2.2 策略更新
PPO的策略更新通过最大化 L CLIP ( θ ) L^{\text{CLIP}}(\theta) LCLIP(θ) 来实现。由于引入了 clip \text{clip} clip 操作,损失函数对 r t ( θ ) r_t(\theta) rt(θ) 的变化在 [ 1 − ϵ , 1 + ϵ ] [1 - \epsilon, 1 + \epsilon] [1−ϵ,1+ϵ] 范围之外不再敏感,从而限制了每次更新的步幅。
2.3 价值网络更新
除了策略网络,PPO还使用价值网络来估计状态值函数 V ( s ) V(s) V(s),其损失函数为:
L VF ( θ ) = E t [ ( V θ ( s t ) − V t target ) 2 ] L^{\text{VF}}(\theta) = \mathbb{E}_t \left[ \left( V_\theta(s_t) - V_t^{\text{target}} \right)^2 \right] LVF(θ)=Et[(Vθ(st)−Vttarget)2]
其中, V t target V_t^{\text{target}} Vttarget 是对真实价值的估计,例如使用TD目标:
V t target = r t + γ V ( s t + 1 ) V_t^{\text{target}} = r_t + \gamma V(s_{t+1}) Vttarget=rt+γV(st+1)
2.4 总损失函数
综合考虑策略损失和价值函数损失,以及可能的熵正则项,PPO的总损失函数为:
L ( θ ) = L CLIP ( θ ) − c 1 L VF ( θ ) + c 2 S [ π θ ] ( s t ) L(\theta) = L^{\text{CLIP}}(\theta) - c_1 L^{\text{VF}}(\theta) + c_2 S[\pi_\theta](s_t) L(θ)=LCLIP(θ)−c1LVF(θ)+c2S[πθ](st)
其中:
- c 1 c_1 c1 和 c 2 c_2 c2 是权衡各项损失的系数。
- S [ π θ ] ( s t ) S[\pi_\theta](s_t) S[πθ](st) 是策略的熵,鼓励探索。
3. 案例分析
为了更好地理解PPO算法,我们在经典的CartPole-v1环境上进行了实验。该环境的目标是控制小车移动,以保持竖立的杆子不倒下。
3.1代码实现
以下是PPO算法在CartPole-v1环境上的部分实现代码:
class PPO:
'''PPO算法'''
def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,
lmbda, epsilon, epochs, device):
self.action_dim = action_dim
self.actor_critic = ActorCritic(state_dim, hidden_dim, action_dim).to(device)
self.actor_optimizer = optim.Adam(self.actor_critic.actor_parameters(), lr=actor_lr)
self.critic_optimizer = optim.Adam(self.actor_critic.critic_parameters(), lr=critic_lr)
self.gamma = gamma # 折扣因子
self.lmbda = lmbda # GAE参数
self.epsilon = epsilon # PPO截断范围
self.epochs = epochs # PPO的更新次数
self.device = device
def take_action(self, state):
'''根据策略网络选择动作'''
state = torch.tensor([state], dtype=torch.float).to(self.device)
with torch.no_grad():
action_probs, _ = self.actor_critic(state)
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
return action.item()
def update(self, transition_dict):
'''更新策略网络和价值网络'''
states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
actions = torch.tensor(transition_dict['actions']).view(-1).to(self.device)
rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
# 计算TD误差和优势函数
_, state_values = self.actor_critic(states)
_, next_state_values = self.actor_critic(next_states)
td_target = rewards + self.gamma * next_state_values * (1 - dones)
delta = td_target - state_values
delta = delta.detach().cpu().numpy()
# Generalized Advantage Estimation (GAE)
advantage_list = []
advantage = 0.0
for delta_t in delta[::-1]:
advantage = self.gamma * self.lmbda * advantage + delta_t[0]
advantage_list.append([advantage])
advantage_list.reverse()
advantages = torch.tensor(advantage_list, dtype=torch.float).to(self.device)
# 计算旧策略的log概率
with torch.no_grad():
action_probs_old, _ = self.actor_critic(states)
dist_old = torch.distributions.Categorical(action_probs_old)
log_probs_old = dist_old.log_prob(actions)
# 更新策略网络和价值网络
for _ in range(self.epochs):
action_probs, state_values = self.actor_critic(states)
dist = torch.distributions.Categorical(action_probs)
log_probs = dist.log_prob(actions)
ratio = torch.exp(log_probs - log_probs_old)
surr1 = ratio * advantages.squeeze()
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages.squeeze()
actor_loss = -torch.mean(torch.min(surr1, surr2))
critic_loss = F.mse_loss(state_values, td_target.detach())
# 更新策略网络
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# 更新价值网络
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
3.2 结果分析
Iteration 1: 100%|██████████| 30/30 [00:00<00:00, 66.19it/s, Episode=30/300, Average Return=10.00]
Iteration 2: 100%|██████████| 30/30 [00:00<00:00, 36.67it/s, Episode=60/300, Average Return=162.90]
Iteration 3: 100%|██████████| 30/30 [00:01<00:00, 24.94it/s, Episode=90/300, Average Return=278.70]
Iteration 4: 100%|██████████| 30/30 [00:01<00:00, 19.59it/s, Episode=120/300, Average Return=287.80]
Iteration 5: 100%|██████████| 30/30 [00:01<00:00, 17.57it/s, Episode=150/300, Average Return=240.70]
Iteration 6: 100%|██████████| 30/30 [00:01<00:00, 21.10it/s, Episode=180/300, Average Return=354.60]
Iteration 7: 100%|██████████| 30/30 [00:02<00:00, 12.90it/s, Episode=210/300, Average Return=450.50]
Iteration 8: 100%|██████████| 30/30 [00:02<00:00, 11.59it/s, Episode=240/300, Average Return=500.00]
Iteration 9: 100%|██████████| 30/30 [00:02<00:00, 11.52it/s, Episode=270/300, Average Return=475.50]
Iteration 10: 100%|██████████| 30/30 [00:02<00:00, 11.31it/s, Episode=300/300, Average Return=500.00]
运行上述代码,可以观察到在训练过程中,智能体的平均回报逐渐提高,最终稳定在较高水平。这表明PPO算法有效地学习到了保持杆子平衡的策略。
从学习曲线可以看出,经过大约200个回合的训练,智能体的表现达到了环境的最高分。这验证了PPO算法在处理连续动作空间和策略优化问题上的有效性。
4. 总结
PPO算法通过引入概率比率的截断和优势函数的估计,实现了高效稳定的策略更新。在CartPole-v1环境上的实验表明,PPO能够快速收敛到最优策略,具有较好的性能和稳定性。由于其简单高效的特点,PPO在强化学习领域得到了广泛的应用和认可。