Gymnasium项目教程:如何记录智能体训练与评估过程
前言
在强化学习的研究和开发过程中,记录智能体的表现是至关重要的环节。Gymnasium提供了两个强大的封装器(Wrapper)来帮助我们完成这项工作:RecordEpisodeStatistics
和RecordVideo
。本文将详细介绍如何使用这两个工具来记录智能体的训练和评估过程。
核心封装器介绍
RecordEpisodeStatistics封装器
这个封装器用于记录每个回合(episode)的关键统计数据,包括:
- 总奖励(Total rewards)
- 回合长度(Episode length)
- 耗时(Time taken)
这些数据对于评估智能体性能至关重要,可以帮助开发者了解训练过程中的进步情况。
RecordVideo封装器
这个封装器能够将智能体的表现录制为mp4视频文件,基于环境的渲染输出。视频记录对于直观展示智能体行为非常有用,特别是在需要视觉化分析问题时。
评估模式下的记录
当我们需要评估一个已经训练好的智能体时,通常会希望记录多个回合的表现。以下是一个完整的示例代码:
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
num_eval_episodes = 4
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="eval",
episode_trigger=lambda x: True)
env = RecordEpisodeStatistics(env, buffer_length=num_eval_episodes)
for episode_num in range(num_eval_episodes):
obs, info = env.reset()
episode_over = False
while not episode_over:
action = env.action_space.sample() # 这里替换为实际智能体的决策逻辑
obs, reward, terminated, truncated, info = env.step(action)
episode_over = terminated or truncated
env.close()
print(f'回合耗时: {env.time_queue}')
print(f'回合总奖励: {env.return_queue}')
print(f'回合长度: {env.length_queue}')
代码解析
-
RecordVideo
参数说明:video_folder
: 视频保存目录name_prefix
: 视频文件名前缀episode_trigger
: 触发记录的条件,这里设置为每个回合都记录
-
RecordEpisodeStatistics
参数说明:buffer_length
: 数据队列的最大长度,用于存储最近的回合统计数据
-
评估结束后,我们可以通过封装器提供的队列(
time_queue
、return_queue
、length_queue
)来获取所有回合的统计数据。
训练模式下的记录
在训练过程中,由于回合数量庞大,我们通常不需要记录每个回合的视频,而是定期记录。同时,我们仍然需要收集每个回合的统计数据。以下是训练模式下记录的示例:
import logging
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
training_period = 250 # 每250个回合记录一次
num_training_episodes = 10_000 # 总训练回合数
env = gym.make("CartPole-v1", render_mode="rgb_array")
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="training",
episode_trigger=lambda x: x % training_period == 0)
env = RecordEpisodeStatistics(env)
for episode_num in range(num_training_episodes):
obs, info = env.reset()
episode_over = False
while not episode_over:
action = env.action_space.sample() # 这里替换为实际智能体的决策逻辑
obs, reward, terminated, truncated, info = env.step(action)
episode_over = terminated or truncated
logging.info(f"回合-{episode_num}", info["episode"])
env.close()
代码解析
episode_trigger
参数现在设置为每250个回合记录一次视频- 使用Python的logging模块记录每个回合的统计数据
- 统计数据可以通过
info["episode"]
字典获取
性能优化建议
-
向量环境:对于评估多个回合,考虑使用向量环境(Vector Environment)并行执行多个环境实例,可以显著提高评估效率。
-
日志系统:除了Python自带的logging模块,还可以考虑集成更强大的日志系统如TensorBoard或Weights & Biases,它们提供了更丰富的数据可视化和分析功能。
-
视频压缩:当录制大量视频时,可以考虑在录制后自动压缩视频文件以节省存储空间。
常见问题解答
Q: 为什么我的视频没有生成? A: 请确保:
- 环境支持
rgb_array
渲染模式 - 指定的视频目录有写入权限
episode_trigger
条件设置正确
Q: 如何自定义视频质量? A: 目前Gymnasium没有直接提供视频质量参数,但你可以通过修改环境的分辨率来间接影响视频质量。
Q: 统计数据队列会占用多少内存? A: 内存占用取决于buffer_length
参数和每个回合的数据量,对于大多数环境来说内存占用可以忽略不计。
结语
通过合理使用Gymnasium提供的记录功能,开发者可以更全面地了解智能体的训练过程和最终表现。无论是用于学术研究还是实际应用开发,这些记录工具都能提供宝贵的反馈信息,帮助改进和优化智能体算法。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考