33,PyTorch 常见强化学习算法介绍

在这里插入图片描述

33, PyTorch 常见强化学习算法介绍

紧接上一节“32, PyTorch 强化学习的基本概念与框架”,本节把 REINFORCE 与 DQN 两张“入场券”进一步展开,系统梳理 PyTorch 社区里最常用、最稳定的 6 大算法族:

  1. Actor-Critic(A2C)
  2. 近端策略优化(PPO)
  3. 深度确定性策略梯度(DDPG)
  4. 双延迟 DDPG(TD3)
  5. 软演员-评论家(SAC)
  6. 可扩展分布式 PPO(SD-PPO / TorchRL)

每个算法给出:适用场景 → 核心公式 → PyTorch 关键实现片段 → 完整训练脚本路径。全部代码可在 GitHub jimn1982/rl_algorithms 一键复现。


1. Actor-Critic(A2C)

| 适用场景 | 离散或连续动作空间,单/多环境并行,收敛速度快于 REINFORCE |
| 核心思想 | 用 Critic 估计 V(s) 作为基线,降低 Actor 梯度方差 |
| 损失函数 |

  • Actor:-log π(a|s) * (R_t - V(s_t))
  • Critic:MSE(V(s_t), R_t)
    | PyTorch 片段(单进程简化版) |
class ActorCriticNet(nn.Module):
    def __init__(self, s_dim, a_dim, hidden=128):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(s_dim, hidden), nn.ReLU()
        )
        self.actor = nn.Linear(hidden, a_dim)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, s):
        h = self.backbone(s)
        logits, v = self.actor(h), self.critic(h).squeeze(-1)
        return logits, v

def a2c_step(batch, net, opt, gamma=0.99, entropy_coef=0.01):
    s, a, r, s2, done = map(torch.tensor, zip(*batch))
    logits, v = net(s)
    _, v_next = net(s2)
    ret = r + gamma * v_next * (~done)
    adv = ret - v
    # Actor loss
    logp = F.log_softmax(logits, -1).gather(1, a.unsqueeze(1)).squeeze()
    actor_loss = -(logp * adv.detach()).mean()
    # Critic loss
    critic_loss = F.mse_loss(v, ret.detach())
    # Entropy bonus
    entropy = -(F.softmax(logits, -1) * F.log_softmax(logits, -1)).sum(-1).mean()
    loss = actor_loss + 0.5 * critic_loss - entropy_coef * entropy
    opt.zero_grad(); loss.backward(); opt.step()

| 训练脚本 | python a2c.py --env CartPole-v1 --n-envs 16
| 收敛速度 | 16 并行环境,CPU 上 30 秒左右 475 分 |


2. 近端策略优化(PPO)

| 适用场景 | 离散/连续动作,高维输入(像素)亦可,工业界最常用 |
| 核心思想 | 限制策略更新幅度,clip ratio r_t = π_new / π_old ∈ [1-ε, 1+ε] |
| 损失函数 |
L = min(r_t A_t, clip(r_t,1-ε,1+ε) A_t) + c1||V-V_target||² - c2 H(π)
| PyTorch 关键实现 |

class PPOAgent:
    def __init__(self, s_dim, a_dim, lr=3e-4, clip_eps=0.2, epochs=10, batch=512):
        self.ac = ActorCriticNet(s_dim, a_dim).to(device)
        self.opt = torch.optim.Adam(self.ac.parameters(), lr=lr)
        self.clip_eps, self.epochs, self.batch = clip_eps, epochs, batch

    def update(self, buf):
        states, actions, adv, ret, old_logp = buf.fetch()
        idx = torch.randperm(len(states))
        for _ in range(self.epochs):
            for start in range(0, len(states), self.batch):
                sl = idx[start:start+self.batch]
                logits, v = self.ac(states[sl])
                logp = F.log_softmax(logits,-1).gather(1, actions[sl]).squeeze()
                ratio = (logp - old_logp[sl]).exp()
                surr1 = ratio * adv[sl]
                surr2 = torch.clamp(ratio, 1-self.clip_eps, 1+self.clip_eps) * adv[sl]
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(v, ret[sl])
                entropy = -(F.softmax(logits,-1)*F.log_softmax(logits,-1)).sum(-1).mean()
                loss = actor_loss + 0.5*critic_loss - 0.01*entropy
                self.opt.zero_grad(); loss.backward(); self.opt.step()

| 训练脚本 | python ppo.py --env LunarLander-v2 --total-steps 3e6
| 收敛速度 | 1×RTX 3060,约 400k 步 250 分 |


3. 深度确定性策略梯度(DDPG)

| 适用场景 | 连续控制,低维或高维观测均可 |
| 核心思想 | 确定性 Actor + 目标网络 + 经验回放,借鉴 DQN |
| 关键公式 |

  • Actor:μ_θ(s) → a
  • Critic:Q_φ(s,a) → R
  • 目标:y = r + γ Q_tgt(s', μ_tgt(s'))
    | PyTorch 片段 |
class DDPGAgent:
    def __init__(self, s_dim, a_dim, a_max=1.0, lr=1e-3, tau=0.005):
        self.actor = nn.Sequential(
            nn.Linear(s_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, a_dim), nn.Tanh()
        ).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.critic = nn.Sequential(
            nn.Linear(s_dim+a_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 1)
        ).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.opt_a = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_c = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.tau = tau
        self.a_max = a_max

    def act(self, s, noise=0.1):
        with torch.no_grad():
            a = self.actor(torch.tensor(s, device=device).float())
        return (a.cpu().numpy() + noise*np.random.randn()).clip(-self.a_max, self.a_max)

    def soft_update(self, target, source):
        for tp, sp in zip(target.parameters(), source.parameters()):
            tp.data.copy_((1-self.tau)*tp.data + self.tau*sp.data)

| 训练脚本 | python ddpg.py --env Pendulum-v1
| 收敛曲线 | 50 000 步左右平均回报 -200 → -140


4. 双延迟 DDPG(TD3)

| 改进点 | 1. 双 Critic 取最小值;2. 延迟 Actor 更新;3. 目标策略平滑 |
| 训练脚本 | python td3.py --env Walker2d-v4
| 性能对比 | Walker2d 得分 DDPG 2500 → TD3 4500


5. 软演员-评论家(SAC)

| 适用场景 | 连续动作,样本高效,自动温度系数 α |
| 核心公式 |

  • 熵正则化目标:J = Σ_t E[r_t + γ(V(s_{t+1}) - α log π(a_{t+1}|s_{t+1}))]
  • 自动 α:α_loss = -log α * (log π + target_entropy).mean()
    | 代码亮点 |
class SACAgent:
    def __init__(self, s_dim, a_dim):
        self.policy = GaussianPolicy(s_dim, a_dim).to(device)  # 输出 μ,σ
        self.q1, self.q2 = Critic(s_dim, a_dim), Critic(s_dim, a_dim)
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_opt = torch.optim.Adam([self.log_alpha], lr=3e-4)

| 训练脚本 | python sac.py --env Humanoid-v4
| 样本效率 | Humanoid 1M 步可达 6000 分(PPO 通常需 3M+)


6. 可扩展分布式 PPO(SD-PPO / TorchRL)

| 目标 | 利用 TorchRL 的 ParallelEnvTensorDict,单机 8 核 CPU 线性加速 |
| 关键 API |

from torchrl.envs import ParallelEnv, GymEnv, TransformedEnv
base_env = ParallelEnv(8, lambda: GymEnv("CartPole-v1"))
env = TransformedEnv(base_env, Compose(...))

| 训练脚本 | torchrun --nproc_per_node=1 sd_ppo.py
| 性能 | 8 并行环境,1 分钟 500 分


7. 算法选型速查表

任务离散动作连续动作样本效率实现难度
A2C
PPO
DQN
DDPG
TD3
SAC极高

8. 小结与延伸

  • 把本节代码与上节“最小框架”对比,可直观看到“策略梯度 → Actor-Critic → 信任域/熵正则化”的进化路径。
  • 下一步实战:
    1. MuJoCo 连续环境测试 TD3/SAC;
    2. torch.compileGaussianPolicy 加速 15%;
    3. 接入 Hydra 做超参数搜索,找出 Humanoid 的最优 α 曲线。

完整代码 & 日志:git clone https://github.com/jimn1982/rl_algorithms
更多技术文章见公众号: 大城市小农民

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

乔丹搞IT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值