32. Basic Concepts and a Framework for Reinforcement Learning in PyTorch

In the previous section we built end-to-end text classification and machine translation models. In this section we switch from supervised learning to reinforcement learning (RL). Using pure PyTorch 2.x code, we will build a runnable RL framework from scratch, organized around two main threads: policy gradients (REINFORCE) and deep Q-networks (DQN). By the end of this section you will have:

  1. An intuitive understanding of the RL five-tuple (S, A, P, R, γ).
  2. A minimal but complete set of PyTorch abstractions: a gymnasium wrapper, a ReplayBuffer, an Agent, and a Trainer.
  3. Two reproducible experiments: REINFORCE on CartPole-v1 and DQN on LunarLander-v2.

1. Reinforcement Learning in 5 Minutes

| Symbol | Meaning | Code mapping |
| --- | --- | --- |
| S | state space | env.observation_space |
| A | action space | env.action_space |
| P(s′∣s, a) | transition probability | not exposed by the API (implicit in the environment dynamics) |
| R(s, a) | immediate reward | the reward returned by env.step(a) |
| γ | discount factor | gamma=0.99 |

The trajectory of one episode is written as
τ = (s₀, a₀, r₁, s₁, a₁, r₂, …, s_T).
Goal: maximize the expected discounted return
J(θ) = E_τ [ ∑_{t=0}^{T−1} γ^t r_{t+1} ].
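
As a concrete check of the objective, here is a tiny standalone sketch (with a made-up reward list, purely for illustration) that accumulates the discounted return of one trajectory:

gamma = 0.99
rewards = [1.0, 1.0, 1.0, 0.0, 5.0]      # hypothetical r_1 ... r_T for one episode

G = 0.0
for t, r in enumerate(rewards):          # G = sum_t gamma^t * r_{t+1}
    G += gamma ** t * r
print(G)                                 # ≈ 7.77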


2. Environment Wrapper: gymnasium + PyTorch

pip install gymnasium[classic-control]  # CartPole
pip install gymnasium[box2d]          # LunarLander
import gymnasium as gym, torch, numpy as np, random
from collections import deque, namedtuple

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class GymEnv:
    def __init__(self, name):
        self.env = gym.make(name)
        self.s_dim = self.env.observation_space.shape[0]
        self.a_dim = self.env.action_space.n \
                      if hasattr(self.env.action_space, 'n') \
                      else self.env.action_space.shape[0]

    def reset(self):
        return self.env.reset()[0]

    def step(self, a):
        s2, r, done, trunc, _ = self.env.step(a)
        return s2, r, done or trunc
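
Before training anything, a quick smoke test of the wrapper is useful; the sketch below (assuming the installs above succeeded) runs one episode with random actions:

# Smoke test: one random-action episode through the wrapper.
env = GymEnv('CartPole-v1')
s, total = env.reset(), 0.0
while True:
    a = env.env.action_space.sample()    # random action, just to exercise step()
    s, r, done = env.step(a)
    total += r
    if done:
        break
print(env.s_dim, env.a_dim, total)       # 4, 2, and a return of roughly 10-40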

3. Experience Replay: ReplayBuffer

Transition = namedtuple('Transition',
                        ('s', 'a', 'r', 's2', 'done'))

class ReplayBuffer:
    def __init__(self, capacity):
        self.buf = deque(maxlen=capacity)

    def push(self, *args):
        self.buf.append(Transition(*args))

    def sample(self, batch_size):
        batch = random.sample(self.buf, batch_size)
        return Transition(*zip(*batch))

    def __len__(self): return len(self.buf)
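
A minimal usage sketch (with dummy transitions, only to show the push/sample round trip):

# Fill the buffer with fake transitions, then draw a mini-batch.
buf = ReplayBuffer(capacity=1000)
for i in range(10):
    s, s2 = np.random.randn(4).astype(np.float32), np.random.randn(4).astype(np.float32)
    buf.push(s, i % 2, 1.0, s2, False)       # (s, a, r, s2, done)

batch = buf.sample(4)
states  = torch.tensor(np.array(batch.s))    # shape (4, 4)
actions = torch.tensor(batch.a)              # shape (4,)
print(states.shape, actions.shape)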

4. REINFORCE: A Policy-Gradient Baseline

4.1 Policy Network

class PolicyNet(torch.nn.Module):
    def __init__(self, s_dim, a_dim, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(s_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, a_dim)
        )
    def forward(self, s):
        return torch.softmax(self.net(s), dim=-1)
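
The softmax output plugs straight into torch.distributions.Categorical, which is exactly how the agent below samples actions; a standalone sketch with a random input state:

# Sample an action and its log-probability (random state, for illustration only).
policy = PolicyNet(s_dim=4, a_dim=2)
s = torch.randn(4)                            # fake CartPole observation
probs = policy(s)                             # shape (2,), sums to 1
dist = torch.distributions.Categorical(probs)
a = dist.sample()
print(probs, a.item(), dist.log_prob(a).item())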

4.2 REINFORCE Agent

class REINFORCEAgent:
    def __init__(self, s_dim, a_dim, lr=1e-3):
        self.policy = PolicyNet(s_dim, a_dim).to(device)
        self.opt = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.log_probs = []
        self.rewards   = []

    def act(self, s):
        s = torch.tensor(s, dtype=torch.float32, device=device)
        probs = self.policy(s)
        m = torch.distributions.Categorical(probs)
        a = m.sample()
        self.log_probs.append(m.log_prob(a))
        return a.item()

    def store(self, r):
        self.rewards.append(r)

    def finish_episode(self, gamma=0.99):
        R = 0
        returns = []
        for r in reversed(self.rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, device=device)
        returns = (returns - returns.mean()) / (returns.std()+1e-9)
        loss = []
        for log_prob, R in zip(self.log_probs, returns):
            loss.append(-log_prob * R)
        self.opt.zero_grad()
        torch.stack(loss).sum().backward()
        self.opt.step()
        self.log_probs.clear()
        self.rewards.clear()
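
finish_episode is a direct transcription of the Monte-Carlo policy-gradient estimator; in the notation of Section 1 (with the standardized return substituted for the raw one as a variance-reduction trick), the ascent direction is

∇_θ J(θ) ≈ ∑_{t=0}^{T−1} ∇_θ log π_θ(a_t | s_t) · G_t,   where G_t = ∑_{k=t}^{T−1} γ^{k−t} r_{k+1}.

Minimizing the summed -log_prob * R terms with Adam therefore performs stochastic gradient ascent on J(θ).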

4.3 Training Loop

def train_reinforce(env_name='CartPole-v1', episodes=1000):
    env = GymEnv(env_name)
    agent = REINFORCEAgent(env.s_dim, env.a_dim)
    returns = deque(maxlen=100)
    for ep in range(episodes):
        s = env.reset()
        ep_rew = 0
        while True:
            a = agent.act(s)
            s2, r, done = env.step(a)
            agent.store(r)
            s, ep_rew = s2, ep_rew + r
            if done: break
        agent.finish_episode()
        returns.append(ep_rew)
        if ep % 50 == 0:
            print(f'ep={ep}, R={np.mean(returns):.1f}')
        if np.mean(returns) >= 475:
            print('solved!')
            break

Result: on a single CPU core, the agent reaches the full score of 500 after roughly 150 episodes.
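
To actually watch the learned behaviour, a hedged evaluation sketch (it assumes you keep a reference to the trained agent, e.g. by making train_reinforce return it) picks the greedy action instead of sampling:

# Greedy evaluation with on-screen rendering (assumes `agent` is a trained REINFORCEAgent).
def evaluate(agent, env_name='CartPole-v1', episodes=5):
    env = gym.make(env_name, render_mode='human')
    for _ in range(episodes):
        s, _ = env.reset()
        done, total = False, 0.0
        while not done:
            with torch.no_grad():
                probs = agent.policy(torch.tensor(s, dtype=torch.float32, device=device))
            a = probs.argmax().item()            # greedy action instead of a sample
            s, r, term, trunc, _ = env.step(a)
            done, total = term or trunc, total + r
        print('return =', total)
    env.close()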


5. DQN: Value-Function Approximation

5.1 Q-Network

class QNet(torch.nn.Module):
    def __init__(self, s_dim, a_dim, hidden=256):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(s_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, a_dim)
        )
    def forward(self, s):
        return self.net(s)

5.2 DQN Agent (with Target Network)

class DQNAgent:
    def __init__(self, s_dim, a_dim,
                 lr=1e-3, gamma=0.99, eps_start=1.0,
                 eps_end=0.05, eps_decay=500,
                 buffer_size=50_000, batch=128,
                 target_sync=500):
        self.q_net  = QNet(s_dim, a_dim).to(device)
        self.tgt_net= QNet(s_dim, a_dim).to(device)
        self.tgt_net.load_state_dict(self.q_net.state_dict())
        self.opt = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.buf = ReplayBuffer(buffer_size)
        self.step_cnt = 0
        self.gamma, self.batch, self.target_sync = gamma, batch, target_sync
        self.a_dim = a_dim                       # kept for random exploration in act()
        self.eps = lambda k: eps_end + (eps_start-eps_end)*np.exp(-k/eps_decay)

    def act(self, s):
        if random.random() < self.eps(self.step_cnt):
            return random.randint(0, self.a_dim - 1)
        with torch.no_grad():
            s = torch.tensor(s, dtype=torch.float32, device=device).unsqueeze(0)
            return self.q_net(s).argmax().item()

    def update(self):
        if len(self.buf) < self.batch: return
        batch = self.buf.sample(self.batch)
        s  = torch.tensor(np.array(batch.s),  device=device, dtype=torch.float32)
        a  = torch.tensor(batch.a, device=device, dtype=torch.long)
        r  = torch.tensor(batch.r, device=device, dtype=torch.float32)
        s2 = torch.tensor(np.array(batch.s2), device=device, dtype=torch.float32)
        done = torch.tensor(batch.done, device=device)

        q = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze()
        with torch.no_grad():
            q_next = self.tgt_net(s2).max(1)[0]
            y = r + self.gamma * q_next * (~done)
        loss = torch.nn.functional.mse_loss(q, y)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        self.step_cnt += 1
        if self.step_cnt % self.target_sync == 0:
            self.tgt_net.load_state_dict(self.q_net.state_dict())
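
update() implements the one-step TD target with the frozen target network Q_{θ⁻}:

y = r + γ · (1 − done) · max_{a′} Q_{θ⁻}(s′, a′),   L(θ) = E[(Q_θ(s, a) − y)²].

Copying θ⁻ ← θ only every target_sync steps keeps the regression target fixed for a while, which is what stabilizes the bootstrapped update.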

5.3 Training Loop

def train_dqn(env_name='LunarLander-v2', episodes=2000):
    env = GymEnv(env_name)
    agent = DQNAgent(env.s_dim, env.a_dim)
    returns = deque(maxlen=100)
    for ep in range(episodes):
        s = env.reset()
        ep_rew = 0
        while True:
            a = agent.act(s)
            s2, r, done = env.step(a)
            agent.buf.push(s, a, r, s2, done)
            agent.update()
            s, ep_rew = s2, ep_rew + r
            if done: break
        returns.append(ep_rew)
        if ep % 50 == 0:
            print(f'ep={ep}, R={np.mean(returns):.1f}, ε={agent.eps(agent.step_cnt):.2f}')
        if np.mean(returns) >= 200:
            print('solved!')
            break

Result: on an RTX 3060, the agent converges to around 250 points after roughly 600 episodes.


6. Summary and Extensions

| Algorithm | Typical setting | Key trick |
| --- | --- | --- |
| REINFORCE | discrete or continuous actions; high variance | subtract a baseline (mean of returns) for variance reduction |
| DQN | high-dimensional states, discrete actions | experience replay + target network |

Possible next steps:

  • Actor-Critic / A2C / PPO: add a learned value baseline to further reduce variance;
  • Continuous control: replace DQN with DDPG or SAC;
  • Distributed training: Ray/RLlib or TorchRL;
  • PyTorch 2.0 compile: apply torch.compile to q_net for roughly a 10-20% speedup (see the sketch below).
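
A one-line sketch of that last idea, assuming PyTorch 2.x (compile both networks so their state_dict keys stay compatible when the target network is synced; the actual speedup depends on hardware):

# Optional: JIT-compile the Q-networks with torch.compile (PyTorch 2.x).
env = GymEnv('LunarLander-v2')
agent = DQNAgent(env.s_dim, env.a_dim)
agent.q_net   = torch.compile(agent.q_net)
agent.tgt_net = torch.compile(agent.tgt_net)   # keep both compiled so load_state_dict keys match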

The complete code is available on GitHub: https://github.com/jimn1982/rl_minimal
In the next section we will combine RL with NLP to implement dialogue policy optimization and human-feedback-based fine-tuning of translation quality.
More articles are available on the WeChat official account 大城市小农民.
