【torchrl】强化学习训练流程

在这里插入图片描述

1 采集数据阶段

上面这个循环是用来采集数据,并且加入到replay buffer中。最终获取的数据是
- s: 当前状态,或者observation
- a: 当前动作,后面重要性采样需要用到
- pa: 选择当前动作的概率,后面重要性采样用到
- r: 当前的奖励值
- s’: 下一个状态

有些算法可能会直接在这里估计state value,即这步的下面一步合并到这步中,那么就会增加state value, 和next state value
这里也可以采集多步,再加入到replay buffer中。

2 训练阶段

一般是采集一步数据,加入replay buffer,然后采样若干数据训练多次。也有是采集多步数据,然后训练多次。

2.1 critic损失计算

  1. 计算target value需要通过advatage+state value计算。
  2. advatage具体的计算方式是:
    adv = r + γ v ′ − v \text{adv} = r+\gamma v'-v adv=r+γvv
    这里的v和v’是当前value和下一步value,adv间接表示了动作的价值。
    当多步的adv加权求和之后,可以获得vae,它表示了更加优秀的动作价值估计,在计算actor loss时使用。

adv+v就是当前状态价值了,或者说 v t = r + γ v ′ v_t=r+\gamma v' vt=r+γv,这个值就是critic的目标。

2.2 actor损失的计算

  1. 首先通过当前的actor网络,估计出了动作action的概率pa
  2. 损失就是-vae(pa-pa’),这里pa’是之前选择动作action的概率。
  3. 对于PPO来说,可以对损失进行进一步裁剪。

以上的训练阶段的两个损失,均是通过PPOLoss计算的。具体可以看:https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py

问题

1 在收集之后马上计算state value,和从replay buffer取出来后再计算state value,两者有什么区别?

下面是gpt回答:

  1. 将状态值加入到Replay Buffer中

优点:

  • 稳定性:由于状态值是与状态和动作对应的估计值,这样存储在replay buffer中的值是基于当时的critic网络计算的,避免了因critic网络更新而引起的估计不稳定。
    减少计算开销:在每次使用replay buffer中的样本进行训练时,直接读取存储的状态值,可以减少重新计算的开销。

缺点:

  • 过时的估计:随着训练的进行,critic网络不断更新,存储在replay buffer中的状态值可能变得过时,不再准确反映当前的网络状态。
  • 存储空间:需要额外的存储空间来保存这些状态值,增加了内存的需求。
  1. 通过当前的Critic网络重新估计

优点:

  • 最新的估计:每次使用replay buffer中的样本时,使用当前的critic网络重新估计状态值,保证了状态值反映的是最新的网络状态。
  • 避免过时信息:由于每次都重新计算,可以避免使用过时的信息进行更新,提高了训练的有效性。

缺点:

  • 计算开销增加:每次从replay buffer中取样本时都需要通过当前的critic网络重新估计状态值,增加了计算开销。
  • 可能的估计不稳定:由于critic网络在训练过程中不断更新,状态值的估计可能会有较大的波动,导致训练的不稳定性。

两者应该都可以,在torchrl中也会在vae计算时候检查是否已经估计了state value,没有的话会自动帮你调用critic估计一下。

2 在网络前加上RNN是否破坏了马尔可夫性

如果只能获取有限状态,应该是不影响。如果是所有状态,则影响。

3 replay buffer应该存储什么

如果采集到数据,马上就计算state value,那么其实不需要保存state,也就是critic(no grad)这一步可以放在step之前,然后在replay buffer中不再存储state,而是state value。这两种方式都可以,看自己选择了。

### 使用 TorchRL 进行强化学习 #### TensorDict 的优势 TorchRL 通过 `TensorDict` 数据结构极大简化了强化学习代码的编写过程[^2]。这种数据结构允许开发者以更高效的方式处理多维张量,支持批量操作并能轻松转换成其他常用格式。 #### 安装与配置环境 为了开始使用 TorchRL 开发项目,需先安装 PyTorch 及其依赖项。可以通过官方文档获取最新的安装指南[^1]。对于特定硬件平台如 NVIDIA Jetson 设备,则可能还需要额外配置 OpenAI Gym 和 Gazebo 模拟器等工具链[^5]。 #### 创建基本框架 构建一个简单的深度 Q 学习模型作为入门案例是一个不错的选择。这涉及到定义好状态空间、动作空间以及设计合适的奖励机制;同时还要搭建起两个神经网络——一个是用于评估当前策略的好坏程度的价值网络(也叫作Q-network),另一个则是用来稳定训练过程的目标网络[^3]。 ```python import torch from torch import nn from torchrl.data import ReplayBuffer, TensorDictReplayBuffer from torchrl.modules.models.exploration import DQNModel class SimpleDQN(nn.Module): def __init__(self, input_size, output_size): super().__init__() self.model = nn.Sequential( nn.Linear(input_size, 128), nn.ReLU(), nn.Linear(128, output_size) ) def forward(self, x): return self.model(x) env = ... # 初始化环境实例 policy_net = SimpleDQN(env.observation_space.shape[0], env.action_space.n).to(device) target_net = SimpleDQN(env.observation_space.shape[0], env.action_space.n).to(device) target_net.load_state_dict(policy_net.state_dict()) replay_buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(buffer_size)) ``` 此段代码展示了如何创建一个简易版的 Deep-Q Network (DQN),其中包含了两部分主要组件:一是负责预测给定状态下采取各行动所获得预期回报值大小的 policy network;二是定期更新参数并与前者同步变化但保持固定的 target network 。此外还设置了经验回放缓冲区以便于存储过往经历供后续采样复习之用。 #### 训练循环逻辑 在完成上述准备工作之后就可以进入正式的训练环节了。通常情况下会按照如下模式执行: - 收集来自环境交互产生的新样本; - 将这些最新收集到的经验存入 replay buffer 中; - 随机抽取一批历史记录来进行 mini-batch 更新; - 利用 Bellman 方程计算 TD error 并据此调整权重参数; - 周期性地复制 policy net 参数至 target net ,确保后者始终反映着前者的最优解形态。 关于 SAC 算法的重要性在于提高了连续控制系统中的表现力,并降低了实现难度,有助于推动更多应用场景下的落地实践和技术革新[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值