32. PyTorch Reinforcement Learning: Basic Concepts and Framework
In the previous section we built end-to-end text classification and machine translation; this section shifts the perspective from supervised learning to reinforcement learning (RL). Using pure PyTorch 2.x code, we will build a runnable RL framework from scratch, organized around two main threads: policy gradients (REINFORCE) and Deep Q-Networks (DQN). After reading this section you will have:
- an intuitive understanding of the RL 5-tuple (S, A, P, R, γ);
- a minimal but complete PyTorch abstraction: gymnasium → ReplayBuffer → Agent → Trainer;
- two reproducible experiments: REINFORCE on CartPole-v1 and DQN on LunarLander-v2.
1. Reinforcement Learning in 5 Minutes
| Symbol | Meaning | Code mapping |
|---|---|---|
| S | state space | env.observation_space |
| A | action space | env.action_space |
| P(s′ \| s, a) | transition probability | implicit in the environment (env.step) |
| R(s, a) | immediate reward | the reward r returned by env.step(a) |
| γ | discount factor | gamma=0.99 |
The trajectory of one episode is written as
τ = (s₀, a₀, r₁, s₁, a₁, r₂, …, s_T).
The goal is to maximize the expected discounted return
J(θ) = E_τ [ ∑_{t=0}^{T−1} γ^t r_{t+1} ].
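To connect this objective with the REINFORCE code in Section 4, we will rely on the standard likelihood-ratio (policy-gradient) identity, stated here as a sketch without derivation:

∇_θ J(θ) ≈ E_τ [ ∑_{t=0}^{T−1} ∇_θ log π_θ(a_t | s_t) · G_t ],  with  G_t = ∑_{k=t}^{T−1} γ^{k−t} r_{k+1}.

The agent in Section 4.2 uses the negative of the inner sum as its loss (so that gradient descent on the loss performs gradient ascent on J), and additionally normalizes G_t as a common variance-reduction heuristic.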
2. Environment Wrapper: gymnasium + PyTorch
pip install gymnasium[classic-control] # CartPole
pip install gymnasium[box2d] # LunarLander
import gymnasium as gym, torch, numpy as np, random
from collections import deque, namedtuple

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class GymEnv:
    """Thin wrapper exposing state/action dimensions and a simplified step()."""
    def __init__(self, name):
        self.env = gym.make(name)
        self.s_dim = self.env.observation_space.shape[0]
        # Discrete action spaces expose .n; continuous (Box) spaces expose .shape.
        self.a_dim = self.env.action_space.n \
            if hasattr(self.env.action_space, 'n') \
            else self.env.action_space.shape[0]

    def reset(self):
        return self.env.reset()[0]          # gymnasium returns (obs, info)

    def step(self, a):
        s2, r, done, trunc, _ = self.env.step(a)
        return s2, r, done or trunc         # treat truncation as episode end
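A quick smoke test of the wrapper, as a minimal sketch (the random-action loop is only for illustration and reuses the imports above):

```python
# Run one CartPole episode with random actions through the wrapper.
env = GymEnv('CartPole-v1')
s = env.reset()
total, done = 0.0, False
while not done:
    a = env.env.action_space.sample()      # random action, just to exercise step()
    s, r, done = env.step(a)
    total += r
print('state dim:', env.s_dim, 'action dim:', env.a_dim, 'random return:', total)
```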
3. Experience Replay: ReplayBuffer
Transition = namedtuple('Transition', ('s', 'a', 'r', 's2', 'done'))

class ReplayBuffer:
    """Fixed-size FIFO buffer of transitions for off-policy learning."""
    def __init__(self, capacity):
        self.buf = deque(maxlen=capacity)

    def push(self, *args):
        self.buf.append(Transition(*args))

    def sample(self, batch_size):
        batch = random.sample(self.buf, batch_size)
        return Transition(*zip(*batch))     # list of Transitions -> Transition of tuples

    def __len__(self):
        return len(self.buf)
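Usage is a push/sample round trip; a small sketch with dummy values (the numbers are arbitrary and reuse the imports above):

```python
buf = ReplayBuffer(capacity=1000)
for _ in range(5):
    buf.push(np.zeros(4), 0, 1.0, np.ones(4), False)   # toy (s, a, r, s2, done)
batch = buf.sample(3)
print(len(buf), len(batch.s), batch.r)                 # 5 3 (1.0, 1.0, 1.0)
```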
4. REINFORCE: A Policy-Gradient Baseline
4.1 Policy Network
class PolicyNet(torch.nn.Module):
    def __init__(self, s_dim, a_dim, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(s_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, a_dim)
        )

    def forward(self, s):
        return torch.softmax(self.net(s), dim=-1)
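A quick shape and probability check (a sketch with random inputs; the numbers are arbitrary):

```python
net = PolicyNet(s_dim=4, a_dim=2)
probs = net(torch.randn(8, 4))           # batch of 8 fake CartPole states
print(probs.shape, probs.sum(dim=-1))    # torch.Size([8, 2]); each row sums to ~1
```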
4.2 REINFORCE Agent
class REINFORCEAgent:
    def __init__(self, s_dim, a_dim, lr=1e-3):
        self.policy = PolicyNet(s_dim, a_dim).to(device)
        self.opt = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.log_probs = []      # log π(a_t|s_t) collected during the episode
        self.rewards = []        # rewards collected during the episode

    def act(self, s):
        s = torch.tensor(s, dtype=torch.float32, device=device)
        probs = self.policy(s)
        m = torch.distributions.Categorical(probs)
        a = m.sample()
        self.log_probs.append(m.log_prob(a))
        return a.item()

    def store(self, r):
        self.rewards.append(r)

    def finish_episode(self, gamma=0.99):
        # Compute discounted returns G_t backwards, then normalize them
        # (a simple variance-reduction trick).
        R, returns = 0, []
        for r in reversed(self.rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, device=device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        loss = [-log_prob * R for log_prob, R in zip(self.log_probs, returns)]
        self.opt.zero_grad()
        torch.stack(loss).sum().backward()
        self.opt.step()
        self.log_probs.clear()
        self.rewards.clear()
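To see what `finish_episode` computes before normalization, here is its discounted-return recursion run by hand on a toy reward sequence (values chosen only for illustration):

```python
# rewards [1, 1, 1] with gamma = 0.99:
#   G2 = 1.0
#   G1 = 1.0 + 0.99 * 1.0  = 1.99
#   G0 = 1.0 + 0.99 * 1.99 = 2.9701
R, gamma, returns = 0.0, 0.99, []
for r in reversed([1.0, 1.0, 1.0]):
    R = r + gamma * R
    returns.insert(0, R)
print(returns)   # [2.9701, 1.99, 1.0]
```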
4.3 Training Loop
def train_reinforce(env_name='CartPole-v1', episodes=1000):
    env = GymEnv(env_name)
    agent = REINFORCEAgent(env.s_dim, env.a_dim)
    returns = deque(maxlen=100)              # moving window of the last 100 returns
    for ep in range(episodes):
        s = env.reset()
        ep_rew = 0
        while True:
            a = agent.act(s)
            s2, r, done = env.step(a)
            agent.store(r)
            s, ep_rew = s2, ep_rew + r
            if done: break
        agent.finish_episode()
        returns.append(ep_rew)
        if ep % 50 == 0:
            print(f'ep={ep}, R={np.mean(returns):.1f}')
        if np.mean(returns) >= 475:          # CartPole-v1 is considered solved at 475
            print('solved!')
            break
Result: on a single CPU core, the agent reaches the maximum score of 500 after roughly 150 episodes.
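After training, a greedy evaluation rollout can look like the sketch below; it assumes `train_reinforce` is modified to return its `agent`, and the greedy `argmax` choice (instead of sampling) is an illustrative evaluation convention, not part of the code above:

```python
def evaluate_policy(agent, env_name='CartPole-v1', episodes=5):
    """Roll out the trained policy greedily (argmax over action probabilities)."""
    env = gym.make(env_name)
    for _ in range(episodes):
        s, _ = env.reset()
        done, ep_rew = False, 0.0
        while not done:
            with torch.no_grad():
                probs = agent.policy(torch.tensor(s, dtype=torch.float32, device=device))
            a = probs.argmax().item()
            s, r, term, trunc, _ = env.step(a)
            done, ep_rew = term or trunc, ep_rew + r
        print('eval return:', ep_rew)
```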
5. DQN: Value Function Approximation
5.1 Q-Network
class QNet(torch.nn.Module):
    def __init__(self, s_dim, a_dim, hidden=256):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(s_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, a_dim)
        )

    def forward(self, s):
        return self.net(s)
5.2 DQN Agent (with a Target Network)
class DQNAgent:
    def __init__(self, s_dim, a_dim,
                 lr=1e-3, gamma=0.99, eps_start=1.0,
                 eps_end=0.05, eps_decay=500,
                 buffer_size=50_000, batch=128,
                 target_sync=500):
        self.a_dim = a_dim
        self.q_net = QNet(s_dim, a_dim).to(device)
        self.tgt_net = QNet(s_dim, a_dim).to(device)
        self.tgt_net.load_state_dict(self.q_net.state_dict())
        self.opt = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.buf = ReplayBuffer(buffer_size)
        self.step_cnt = 0
        self.gamma, self.batch, self.target_sync = gamma, batch, target_sync
        # Exponentially decaying epsilon schedule for epsilon-greedy exploration.
        self.eps = lambda k: eps_end + (eps_start - eps_end) * np.exp(-k / eps_decay)

    def act(self, s):
        if random.random() < self.eps(self.step_cnt):
            return random.randrange(self.a_dim)          # explore: random action
        with torch.no_grad():
            s = torch.tensor(s, dtype=torch.float32, device=device).unsqueeze(0)
            return self.q_net(s).argmax().item()         # exploit: greedy action

    def update(self):
        if len(self.buf) < self.batch:
            return
        batch = self.buf.sample(self.batch)
        s = torch.tensor(np.array(batch.s), device=device, dtype=torch.float32)
        a = torch.tensor(batch.a, device=device, dtype=torch.long)
        r = torch.tensor(batch.r, device=device, dtype=torch.float32)
        s2 = torch.tensor(np.array(batch.s2), device=device, dtype=torch.float32)
        done = torch.tensor(batch.done, device=device, dtype=torch.bool)
        # Q(s, a) for the actions actually taken.
        q = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_next = self.tgt_net(s2).max(1)[0]          # max_a' Q_tgt(s', a')
            y = r + self.gamma * q_next * (~done)        # no bootstrap at episode end
        loss = torch.nn.functional.mse_loss(q, y)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        self.step_cnt += 1
        if self.step_cnt % self.target_sync == 0:
            self.tgt_net.load_state_dict(self.q_net.state_dict())
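For reference, the target that `update` regresses Q(s, a) towards is the standard one-step TD target built from the target network:

y = r + γ · (1 − done) · max_{a′} Q_tgt(s′, a′),

and the loss is the mean-squared error between Q(s, a) and y; the `(~done)` factor in the code is exactly the (1 − done) mask that stops bootstrapping at episode ends.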
5.3 Training Loop
def train_dqn(env_name='LunarLander-v2', episodes=2000):
    env = GymEnv(env_name)
    agent = DQNAgent(env.s_dim, env.a_dim)
    returns = deque(maxlen=100)
    for ep in range(episodes):
        s = env.reset()
        ep_rew = 0
        while True:
            a = agent.act(s)
            s2, r, done = env.step(a)
            agent.buf.push(s, a, r, s2, done)    # store the transition
            agent.update()                       # one gradient step per env step
            s, ep_rew = s2, ep_rew + r
            if done: break
        returns.append(ep_rew)
        if ep % 50 == 0:
            print(f'ep={ep}, R={np.mean(returns):.1f}, ε={agent.eps(agent.step_cnt):.2f}')
        if np.mean(returns) >= 200:              # LunarLander-v2 is considered solved at 200
            print('solved!')
            break
Result: on an RTX 3060, the agent converges to around 250 points after roughly 600 episodes.
6. Summary and Next Steps
| Algorithm | Typical setting | Key technique |
|---|---|---|
| REINFORCE | discrete or continuous actions; high variance | mean subtraction as a baseline (variance reduction) |
| DQN | high-dimensional states, discrete actions | experience replay + target network |
Where to go next:
- Actor-Critic / A2C / PPO: introduce a value baseline to reduce variance;
- Continuous control: replace DQN with DDPG or SAC;
- Distributed training: Ray/RLlib or TorchRL;
- PyTorch 2.0 compile: apply torch.compile to q_net for a 10–20% speedup (see the sketch after this list).
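The `torch.compile` idea from the last bullet, as a minimal sketch (PyTorch ≥ 2.0 assumed; the 10–20% figure is the author's estimate and will vary with hardware). Compiling both networks keeps their state_dict key names aligned for the periodic target sync, since compiled modules may prefix their parameter keys:

```python
agent = DQNAgent(8, 4)                        # LunarLander-v2: 8-dim state, 4 actions
agent.q_net = torch.compile(agent.q_net)      # graph-compile the online network
agent.tgt_net = torch.compile(agent.tgt_net)  # compile the target network as well
```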
The complete code is available on GitHub: https://github.com/jimn1982/rl_minimal
In the next section we will combine RL with NLP, implementing dialogue policy optimization and fine-tuning translation quality with human feedback.
More articles on the WeChat official account: 大城市小农民