强化学习之SAC算法

前言:
在正文开始之前,首先给大家介绍一个不错的人工智能学习教程:https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程,感兴趣的读者可以自行查阅。


一、引言

强化学习近年来在人工智能领域取得了显著的进展,特别是在连续控制任务中,**Soft Actor-Critic(SAC)**算法因其稳定性和高效性受到广泛关注。SAC是一种基于熵正则化的深度强化学习算法,它在策略更新中引入了熵项,鼓励策略的探索性,从而提高了学习效率。

二、算法原理详解

2.1 软策略与熵正则化

SAC算法的核心思想是通过最大化策略的期望奖励和策略熵之和,即:

J ( π ) = ∑ t = 0 ∞ E ( s t , a t ) ∼ ρ π [ r ( s t , a t ) + α H ( π ( ⋅ ∣ s t ) ) ] J(\pi) = \sum_{t=0}^{\infty} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot|s_t)) \right] J(π)=t=0E(st,at)ρπ[r(st,at)+αH(π(st))]

其中, H ( π ( ⋅ ∣ s t ) ) \mathcal{H}(\pi(\cdot|s_t)) H(π(st)) 表示在状态 s t s_t st 下策略的熵, α \alpha α 是权衡奖励和熵的正则化系数。

2.2 策略目标

策略网络的目标是最小化以下损失函数:

J π = E s t ∼ D [ E a t ∼ π [ α log ⁡ ( π ( a t ∣ s t ) ) − Q ( s t , a t ) ] ] J_{\pi} = \mathbb{E}_{s_t \sim D} \left[ \mathbb{E}_{a_t \sim \pi} \left[ \alpha \log \left( \pi(a_t|s_t) \right) - Q(s_t, a_t) \right] \right] Jπ=EstD[Eatπ[αlog(π(atst))Q(st,at)]]

其中, Q ( s t , a t ) Q(s_t, a_t) Q(st,at) 是状态动作值函数, D D D 是经验回放池。

2.3 自动调整熵正则项

SAC算法引入了自动调整正则化系数 α \alpha α 的机制,使策略的熵接近目标熵 H target \mathcal{H}_{\text{target}} Htarget。温度参数的更新目标为:

J ( α ) = E a t ∼ π [ − α ( log ⁡ π ( a t ∣ s t ) + H target ) ] J(\alpha) = \mathbb{E}_{a_t \sim \pi} \left[ -\alpha \left( \log \pi(a_t|s_t) + \mathcal{H}_{\text{target}} \right) \right] J(α)=Eatπ[α(logπ(atst)+Htarget)]

2.4 双重 Q 网络

为了减小值函数估计的偏差,SAC采用了双重 Q 网络,即使用两个独立的 Q 网络 Q θ 1 Q_{\theta_1} Qθ1 Q θ 2 Q_{\theta_2} Qθ2。值函数的更新目标为:

J Q = E ( s t , a t , r t , s t + 1 ) ∼ D [ ( Q θ i ( s t , a t ) − y t ) 2 ] J_Q = \mathbb{E}_{(s_t, a_t, r_t, s_{t+1}) \sim D} \left[ \left( Q_{\theta_i}(s_t, a_t) - y_t \right)^2 \right] JQ=E(st,at,rt,st+1)D[(Qθi(st,at)yt)2]

其中, y t y_t yt 是目标值:

y t = r t + γ ( min ⁡ i = 1 , 2 Q θ ˉ i ( s t + 1 , a t + 1 ) − α log ⁡ π ( a t + 1 ∣ s t + 1 ) ) y_t = r_t + \gamma \left( \min_{i=1,2} Q_{\bar{\theta}_i}(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1}|s_{t+1}) \right) yt=rt+γ(i=1,2minQθˉi(st+1,at+1)αlogπ(at+1st+1))

θ ˉ i \bar{\theta}_i θˉi 表示目标网络的参数, a t + 1 a_{t+1} at+1 由当前策略采样。

2.5 目标网络软更新

目标网络参数通过软更新方式进行更新:

θ ˉ i ← τ θ i + ( 1 − τ ) θ ˉ i \bar{\theta}_i \leftarrow \tau \theta_i + (1 - \tau) \bar{\theta}_i θˉiτθi+(1τ)θˉi

其中, τ \tau τ 是软更新系数,通常取值较小,如 0.005 0.005 0.005

三、案例分析

为了验证 SAC 算法的有效性,我们在经典的 Pendulum-v1 环境上进行了实验。该环境的目标是通过施加力矩,使摆杆保持竖直向上。

3.1 代码实现

以下是 SAC 算法在 Pendulum-v1 环境上的部分实现代码:

# SAC智能体
class SACContinuous:
    ''' 处理连续动作的SAC算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # 策略网络
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第一个Q网络
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第二个Q网络
        self.target_critic_1 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第一个目标Q网络
        self.target_critic_2 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第二个目标Q网络
        # 令目标Q网络的初始参数和Q网络一样
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=actor_lr)
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(),
                                             lr=critic_lr)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(),
                                             lr=critic_lr)
        # 使用alpha的log值,可以使训练结果比较稳定
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float, requires_grad=True, device=device)
        self.alpha_optimizer = optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # 目标熵的大小
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action, _ = self.actor(state)
        return action.cpu().detach().numpy()[0]

    def calc_target(self, rewards, next_states, dones):  # 计算目标Q值
        with torch.no_grad():
            next_actions, next_log_prob = self.actor(next_states)
            entropy = -next_log_prob
            q1_value = self.target_critic_1(next_states, next_actions)
            q2_value = self.target_critic_2(next_states, next_actions)
            min_q = torch.min(q1_value, q2_value)
            target_q = rewards + (1 - dones) * self.gamma * (min_q + self.log_alpha.exp() * entropy)
        return target_q

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, replay_buffer):
        if len(replay_buffer) < BATCH_SIZE:
            return
        states, actions, rewards, next_states, dones = replay_buffer.sample()
        # 对奖励进行缩放
        rewards = (rewards + 8.0) / 8.0

        # 更新两个Q网络
        td_target = self.calc_target(rewards, next_states, dones)
        q1_loss = F.mse_loss(self.critic_1(states, actions), td_target)
        q2_loss = F.mse_loss(self.critic_2(states, actions), td_target)
        self.critic_1_optimizer.zero_grad()
        q1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        q2_loss.backward()
        self.critic_2_optimizer.step()

        # 更新策略网络
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        min_q = torch.min(q1_value, q2_value)
        actor_loss = (self.log_alpha.exp() * log_prob - min_q).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # 更新目标网络
        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)

3.2 结果分析

运行上述代码,可以观察到智能体的回报随着训练的进行逐步提升。

Episode 10	Average Score: -1391.52
Episode 20	Average Score: -661.16
Episode 30	Average Score: -136.22
Episode 40	Average Score: -120.74
Episode 50	Average Score: -119.58
Episode 60	Average Score: -119.08
Episode 70	Average Score: -118.90
Episode 80	Average Score: -117.80
Episode 90	Average Score: -117.68
Episode 100	Average Score: -117.76
Episode 110	Average Score: -117.54
Episode 120	Average Score: -117.55
Episode 130	Average Score: -117.56
Episode 140	Average Score: -117.53
Episode 150	Average Score: -117.68
Episode 160	Average Score: -117.63
Episode 170	Average Score: -117.80
Episode 180	Average Score: -117.81
Episode 190	Average Score: -117.32
Episode 200	Average Score: -117.41

绘制的学习曲线如下所示:

从图中可以看到,智能体在大约 25 个回合后,达到了较高的平均回报,表明 SAC 算法在连续动作空间的控制任务中具有良好的性能。

四、总结

本文详细介绍了 Soft Actor-Critic 算法的原理,并在 Pendulum-v1 环境上进行了实验验证。SAC 通过引入策略熵和自动温度调整机制,实现了高效稳定的策略学习,非常适合处理高维连续动作空间的强化学习任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

抱抱宝

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值