Stable-Baselines 强化学习库实战指南
本文将通过多个实战案例,深入讲解如何使用 Stable-Baselines 这一强大的强化学习库来训练智能体。我们将从基础用法开始,逐步深入到高级特性,帮助读者掌握强化学习的实际应用技巧。
基础入门:训练、保存与加载模型
让我们从最基础的例子开始 - 在 LunarLander 环境中训练一个 DQN 智能体。
环境准备
LunarLander 是一个经典的强化学习环境,模拟登月器着陆过程。要使用这个环境,需要先安装 Box2D 物理引擎:
apt install swig
pip install box2d box2d-kengz
代码实现
import gym
from stable_baselines import DQN
from stable_baselines.common.evaluation import evaluate_policy
# 创建环境
env = gym.make('LunarLander-v2')
# 初始化DQN智能体
model = DQN('MlpPolicy', env, learning_rate=1e-3, prioritized_replay=True, verbose=1)
# 训练智能体
model.learn(total_timesteps=int(2e5))
# 保存模型
model.save("dqn_lunar")
del model # 删除模型以演示加载过程
# 加载训练好的模型
model = DQN.load("dqn_lunar")
# 评估模型性能
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# 运行训练好的智能体
obs = env.reset()
for i in range(1000):
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
关键点说明
learn()
方法用于训练模型,total_timesteps
参数控制训练步数save()
和load()
方法用于模型的保存和加载evaluate_policy()
可以方便地评估模型性能- 使用
predict()
方法让智能体与环境交互
多进程训练:向量化环境
强化学习训练通常非常耗时,使用多进程可以显著加速训练过程。Stable-Baselines 提供了 SubprocVecEnv
和 make_vec_env
等工具来简化多进程环境的创建。
CartPole 多进程示例
import gym
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACKTR
def make_env(env_id, rank, seed=0):
def _init():
env = gym.make(env_id)
env.seed(seed + rank)
return env
return _init
env_id = "CartPole-v1"
num_cpu = 4 # 使用4个进程
# 创建向量化环境
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
# 或者使用更简便的make_vec_env
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0)
model = ACKTR(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
优势分析
- 多进程并行训练大幅提高数据采集效率
- 不同环境实例使用不同随机种子,增加数据多样性
- 向量化环境自动处理各进程间的同步问题
训练监控与回调函数
在长时间训练过程中,监控训练进度并保存最佳模型非常重要。Stable-Baselines 提供了回调函数机制来实现这一功能。
自定义回调示例
from stable_baselines.common.callbacks import BaseCallback
import numpy as np
import os
class SaveOnBestTrainingRewardCallback(BaseCallback):
def __init__(self, check_freq, log_dir, verbose=1):
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
self.log_dir = log_dir
self.save_path = os.path.join(log_dir, 'best_model')
self.best_mean_reward = -np.inf
def _on_step(self) -> bool:
if self.n_calls % self.check_freq == 0:
# 计算最近100个episode的平均奖励
x, y = ts2xy(load_results(self.log_dir), 'timesteps')
if len(x) > 0:
mean_reward = np.mean(y[-100:])
if mean_reward > self.best_mean_reward:
self.best_mean_reward = mean_reward
self.model.save(self.save_path)
return True
# 使用回调
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir='./logs/')
model.learn(total_timesteps=int(1e5), callback=callback)
回调函数高级用法
- 可以记录训练指标到TensorBoard
- 实现早停机制防止过拟合
- 动态调整超参数
Atari游戏训练
Stable-Baselines 提供了专门针对Atari游戏的预处理工具,简化了训练流程。
Pong游戏训练示例
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack
from stable_baselines import ACER
# 创建预处理后的Atari环境
env = make_atari_env('PongNoFrameskip-v4', num_env=4, seed=0)
# 帧堆叠处理
env = VecFrameStack(env, n_stack=4)
model = ACER('CnnPolicy', env, verbose=1)
model.learn(total_timesteps=25000)
Atari训练要点
make_atari_env
自动处理帧跳转、灰度化等预处理- 帧堆叠(VecFrameStack)让模型能够感知运动信息
- 使用CNN策略处理图像输入
自定义策略网络
Stable-Baselines 允许用户自定义神经网络结构,满足特定任务需求。
自定义MLP策略
from stable_baselines.common.policies import FeedForwardPolicy
class CustomPolicy(FeedForwardPolicy):
def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs,
net_arch=[dict(pi=[128, 128, 128], vf=[128, 128, 128])],
feature_extraction="mlp")
model = A2C(CustomPolicy, 'LunarLander-v2')
model.learn(total_timesteps=100000)
自定义策略说明
net_arch
定义策略网络和价值网络结构- 可以分别为pi(策略)和vf(价值)函数设计不同结构
- 支持CNN、LSTM等多种网络类型
参数访问与修改
Stable-Baselines 提供了直接访问和修改模型参数的接口,这在实现进化策略等算法时非常有用。
参数操作示例
# 获取当前参数
params = model.get_parameters()
# 修改参数
new_params = {k: v + np.random.normal(size=v.shape) for k,v in params.items()}
# 加载新参数
model.load_parameters(new_params)
循环策略(LSTM)
对于部分可观测环境,使用循环策略可以取得更好效果。
LSTM策略使用
model = PPO2('MlpLstmPolicy', 'CartPole-v1', nminibatches=1)
# 预测时需要维护状态
state = None
done = [False]
action, state = model.predict(obs, state=state, mask=done)
循环策略要点
- 测试时必须使用与训练时相同数量的环境
- 需要通过mask重置LSTM状态
- 适合具有时间依赖性的任务
总结
通过本文的多个实战案例,我们全面介绍了 Stable-Baselines 的主要功能和使用技巧。从基础操作到高级特性,这套工具链为强化学习研究和应用提供了强大支持。无论是简单的经典控制问题,还是复杂的Atari游戏,Stable-Baselines 都能提供高效的解决方案。
读者可以根据实际需求,选择合适的算法和环境,结合回调函数、参数调整等高级功能,开发出性能优异的强化学习智能体。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考