PyTorch RL中的数据收集器(Collectors)详解
数据收集器概述
在强化学习框架中,数据收集器(Collectors)扮演着至关重要的角色,它们类似于PyTorch中的DataLoader,但有两点关键区别:
- 数据收集器从动态数据源(如环境交互)中收集数据
- 数据收集过程需要使用模型(通常是正在训练的模型的一个版本)
PyTorch RL中的数据收集器主要接受两个参数:环境(或环境构造器列表)和策略。它们会迭代执行环境步骤和策略查询,在达到预定义步数后向用户交付收集到的数据堆栈。当环境达到终止状态或超过预设步数时,环境会被重置。
数据收集器的类型与选择
由于数据收集可能是计算密集型过程,合理配置执行超参数至关重要。首要考虑的是数据收集应与优化步骤串行执行还是并行执行。
同步数据收集器
- SyncDataCollector:在训练工作线程上执行数据收集
- MultiSyncDataCollector:将工作负载分配到多个工作线程,并将结果聚合后交付给训练线程
异步数据收集器
- MultiaSyncDataCollector:在多个工作线程上执行数据收集,并交付最先收集到的批次数据。这种执行会持续且与网络训练同时进行,意味着用于数据收集的策略权重可能略微滞后于训练线程上的策略配置。
对于远程执行的rollout(MultiSyncDataCollector或MultiaSyncDataCollector),需要使用collector.update_policy_weights_()
同步远程策略权重,或在构造函数中设置update_at_each_batch=True
。
关键配置参数
计算设备配置
在远程设置中,需要考虑数据收集设备以及环境和策略操作的执行设备。例如:
- CPU上执行的策略可能比CUDA设备上的慢
- 当多个推理工作线程同时运行时,跨可用设备分发计算工作负载可以加速收集或避免OOM错误
- 批大小和传递设备(存储数据等待传递给收集工作线程的设备)也会影响内存管理
关键控制参数:
devices
:控制执行设备(策略的设备)storing_device
:控制在rollout期间存储环境和数据的设备
其他重要参数
max_frames_per_traj
:调用env.reset()
前的帧数frames_per_batch
:每次迭代交付的帧数init_random_frames
:随机步数(调用env.rand_step()
的步数)reset_at_each_iter
:如果为True,每次批收集后重置环境split_trajs
:如果为True,轨迹将被分割并以填充的tensordict形式交付exploration_type
:与策略一起使用的探索策略reset_when_done
:环境在达到完成状态时是否应重置
收集器与批大小
不同收集器组织运行环境的方式不同,因此数据会具有不同的批大小。下表总结了收集数据时的预期情况:
| 收集器类型 | 单环境 | 批处理环境(n=P) | |------------|--------|-----------------| | SyncDataCollector | [T] | [P, T] | | MultiSyncDataCollector (n=B) | [B, T] | [B, P, T] | | MultiaSyncDataCollector (n=B) | [T] | [P, T] |
警告:MultiSyncDataCollector
不应与cat_results=0
一起使用,因为数据将沿批处理维度堆叠(对于批处理环境)或沿时间维度堆叠(对于单环境),这可能在交换时引入混淆。
权重同步与策略复制
在分布式和多进程环境中,确保所有策略实例与最新训练权重同步对于保持性能一致性至关重要。PyTorch RL提供了灵活的机制来跨不同设备和进程更新策略权重。
权重更新机制
权重同步过程通过WeightUpdaterBase
类实现,它提供了结构化接口用于实现自定义权重更新逻辑。每个收集器(服务器或工作线程)都应有一个WeightUpdaterBase
实例来处理与策略的权重同步。
扩展更新器类
API允许用户通过自定义实现扩展更新器类,这在涉及复杂网络架构或专用硬件设置时特别有用。通过实现这些基类中的抽象方法,用户可以定义如何检索、转换和应用权重。
收集器与回放缓冲区的互操作性
在需要从回放缓冲区采样单个转换的最简单场景中,构建收集器时只需很少关注。收集后展平数据就足以填充存储:
memory = ReplayBuffer(
storage=LazyTensorStorage(N),
transform=lambda data: data.reshape(-1))
for data in collector:
memory.extend(data)
如果需要收集轨迹切片,推荐的方法是创建多维缓冲区并使用SliceSampler
采样器类。必须确保传递给缓冲区的数据形状正确,time
和batch
维度清晰分离。
异步运行收集器
将回放缓冲区传递给收集器允许我们启动收集并摆脱收集器的迭代性质。要在后台运行数据收集器,只需运行start()
方法:
collector = SyncDataCollector(..., replay_buffer=rb)
collector.start()
time.sleep(10)
for i in range(optim_steps):
data = rb.sample()
# 训练循环其余部分
警告:异步运行收集器将收集与训练解耦,这意味着训练性能可能会因硬件、负载和其他因素而有很大不同。确保理解这可能如何影响您的算法!
分布式数据收集器
PyTorch RL提供了一组分布式数据收集器,支持多种后端('gloo'、'nccl'、'mpi')和启动器('ray'、submitit或torch.multiprocessing)。它们可以高效地用于同步或异步模式,在单个节点或多个节点上。
选择子收集器
所有分布式收集器都支持各种单机收集器。一般来说,多进程收集器的IO占用比需要每一步通信的并行环境要低。然而,模型规格在相反方向起作用,因为使用并行环境将导致策略(和/或转换)执行更快,因为这些操作将被向量化。
选择收集器设备
在CPU上使用并行环境和多进程环境时,通过共享内存缓冲区在进程间共享数据。根据所用机器的能力,这可能比在GPU上共享数据要慢得多。在实践中,这意味着在构建并行环境或收集器时使用device="cpu"
可能比使用device="cuda"
(当可用时)导致更慢的收集。
总结
PyTorch RL中的数据收集器提供了强大而灵活的工具来管理强化学习中的数据收集过程。通过理解不同类型的收集器、它们的配置参数以及与回放缓冲区的交互方式,开发者可以构建高效的数据管道,支持从简单到复杂的强化学习应用场景。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考