OpenAI SpinningUp项目中的实验输出与模型加载指南
概述
在强化学习项目中,训练过程的输出管理和模型加载是至关重要的环节。本文将深入解析SpinningUp项目中的实验输出结构、保存机制以及如何加载和使用训练好的策略模型。
实验输出结构详解
SpinningUp算法的每次训练运行都会生成一系列输出文件,这些文件包含了训练过程的完整记录。了解这些输出的结构和用途对于后续分析和模型使用至关重要。
核心输出目录结构
每个训练运行会生成以下文件和目录:
-
框架相关保存目录
pyt_save/
:PyTorch实现专用,包含恢复训练好的智能体和价值函数所需的所有内容tf1_save/
:TensorFlow实现专用,功能与PyTorch版本类似
-
训练记录文件
config.json
:完整的训练参数配置记录progress.txt
:训练过程中记录的各项指标vars.pkl
:算法状态保存文件(主要保存环境副本)
重要注意事项
- 环境保存可能失败:某些环境(特别是旧版Gym的Box2D环境)无法被pickle序列化,导致
vars.pkl
为空 - 文件结构变更历史:TensorFlow的保存目录从
simple_save/
变更为tf1_save/
- 手动使用建议:通常只需直接查看
config.json
文件,其他文件应通过专用工具访问
框架特定保存结构
PyTorch保存结构
pyt_save
目录包含:
model.pt
:通过torch.save
创建的PyTorch模型文件,可恢复为带有act
方法的ActorCritic对象
TensorFlow保存结构
tf1_save
目录包含:
variables/
:TensorFlow Saver的输出目录model_info.pkl
:包含帮助解包保存模型的信息字典saved_model.pb
:TensorFlow SavedModel所需的协议缓冲区文件
输出目录位置管理
默认情况下,实验结果保存在SpinningUp包同级目录下的data
文件夹中。用户可以通过修改spinup/user_config.py
中的DEFAULT_DATA_DIR
来更改默认结果目录。
加载和运行训练好的策略
环境成功保存的情况
当环境与智能体一起成功保存时,可以使用以下命令轻松测试训练好的智能体:
python -m spinup.run test_policy path/to/output_directory
常用测试选项
-
最大测试长度 (
-l/--len
)- 类型:整数
- 默认值:0(无限制)
- 注意:设置为0不会覆盖Gym环境自带的TimeLimit包装器限制
-
测试回合数 (
-n/--episodes
)- 类型:整数
- 默认值:100
-
禁用渲染 (
-nr/--norender
)- 作用:仅打印回合回报和长度,不显示渲染画面
- 适用场景:需要快速评估性能而不关心可视化时
-
指定迭代版本 (
-i/--itr
)- 类型:整数
- 默认值:-1(使用最新快照)
- 注意:默认算法配置不支持此功能,需要修改代码启用多快照保存
-
确定性策略 (
-d/--deterministic
)- 特殊用途:仅适用于SAC算法
- 作用:使用确定性均值策略而非训练时的随机策略进行评估
环境保存失败的情况
当环境未能成功保存时,可以通过以下方式手动加载和测试策略:
from spinup.utils.test_policy import load_policy_and_env, run_policy
import your_env
_, get_action = load_policy_and_env('/path/to/output_directory')
env = your_env.make()
run_policy(env, get_action)
使用训练好的价值函数
test_policy.py
工具不支持直接查看训练好的价值函数。如需使用:
- PyTorch实现:使用
torch.load
加载模型文件,参考算法文档了解ActorCritic对象的模块结构 - TensorFlow实现:使用
restore_tf_graph
函数加载计算图,参考算法文档了解保存的函数
最佳实践建议
- 参数记录:始终检查
config.json
以确保训练参数被正确记录 - 环境兼容性:在开始长时间训练前,先验证环境能否被正确保存
- 性能评估:对于无界面服务器,使用
--norender
选项进行快速评估 - 版本控制:考虑修改代码以支持多快照保存,便于分析训练过程中的策略演变
通过深入理解SpinningUp的实验输出结构和加载机制,研究人员可以更高效地管理和分析强化学习实验,为算法改进和性能优化奠定坚实基础。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考