ray.rllib-入门实践-12:自定义policy

部署运行你感兴趣的模型镜像

        在本博客开始之前,先厘清一下几个概念之间的区别与联系:env,  agent,  model, algorithm, policy. 

        强化学习由两部分组成: 环境(env)和智能体(agent)。环境(env)提供观测值和奖励; agent读取观测值,输出动作或决策。agent是algorithm的类对象。 policy是algorithm的子类, 比如ppo, dqn等。因此,自定义policy本质上是自定义algorithm.  algorithm 主要由两部分组成: 网络结构(model)和损失函数(loss)。 网络结构(model)的自定义由上一个博客ray.rllib-入门实践-11: 自定义模型/网络 进行了介绍:在alrorithm外创建新的model类, 通过 AlgorithmConfig类传入algorithm。因此, c从实际操作上, 自定义algorithm就变成了自定义algorithm 的 loss.

        因此,本博客所提到的自定义policy, 本质上就是继承一个Algorithm, 并修改它的loss函数。

        与之前介绍的自定义env, 自定义model一样, 自定义policy也包含三个步骤:

        1. 继承某个Policy, 创建一个新Policy类, 修改它的损失函数。

        2. 把自己的Policy封装为一个Algorithm, 使ray可识别

        3. 配置使用自己的Policy.

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 自定义 policy

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch

## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):
    """PyTorch policy class used with PPO."""
    def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): 
        PPOTorchPolicy.__init__(self,observation_space,action_space,config)
        ## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config

    @override(PPOTorchPolicy) 
    def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):
        ## 原始损失
        original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。

        ## 新增自定义损失,这里以正则化损失作为示例
        addiontial_loss = torch.tensor(0.0) ## 自己定义的loss
        addiontial_loss = torch.tensor(0.)
        for param in model.parameters():
            addiontial_loss += torch.norm(param)
        ## 得到更新后的损失
        new_loss = original_loss + 0.01 * addiontial_loss
        return new_loss

二、 把自己的policy封装在一个算法中

## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):
    ## 重写 get_default_policy_class 函数, 使其返回自定义的policy 
    def get_default_policy_class(self, config):
        return MY_PPOTorchPolicy

三、使用自己的策略创建智能体,执行训练

配置方法1:

## 三、使用自己的策略创建智能体,执行训练
## method-1
from ray.tune.logger import pretty_print 
config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = config.build()
result = algo.train()
print(pretty_print(result))

配置方法2:

## 3. 使用新策略执行训练
## method-2
from ray.tune.logger import pretty_print 
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里使用自己的policy
result = algo.train()
print(pretty_print(result))

四、代码汇总:

import torch 
import gymnasium as gym 
from gymnasium import spaces
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from typing import Dict, List, Type, Union
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.logger import pretty_print 

## 1. 自定义 policy, 主要是改变 policy 的 loss 的计算  # 神经网络的损失函数
class MY_PPOTorchPolicy(PPOTorchPolicy):
    """PyTorch policy class used with PPO."""
    def __init__(self, observation_space:gym.spaces.Box, action_space:gym.spaces.Box, config:PPOConfig): 
        PPOTorchPolicy.__init__(self,observation_space,action_space,config)
        ## PPOTorchPolicy 内部对 PPOConfig 格式的config 执行了to_dict()操作,后面可以以 dict 的形式使用 config

    @override(PPOTorchPolicy) 
    def loss(self,model: ModelV2,dist_class: Type[ActionDistribution],train_batch: SampleBatch):
        ## 原始损失
        original_loss = super().loss(model, dist_class, train_batch) # PPO原来的损失函数, 也可以完全自定义新的loss函数, 但是非常不建议。

        ## 新增自定义损失,这里以正则化损失作为示例
        addiontial_loss = torch.tensor(0.0) ## 自己定义的loss
        addiontial_loss = torch.tensor(0.)
        for param in model.parameters():
            addiontial_loss += torch.norm(param)
        ## 得到更新后的损失
        new_loss = original_loss + 0.01 * addiontial_loss
        return new_loss
    
## 2. 把自己的 policy 封装在算法中: 
##    继承自PPO, 创建一个新的算法类, 默认调用的是自定义的policy
class MY_PPO(PPO):
    ## 重写 get_default_policy_class 函数, 使其返回自定义的policy 
    def get_default_policy_class(self, config):
        return MY_PPOTorchPolicy

## 三、使用自己的策略创建智能体,执行训练
## method-1
# config = PPOConfig(algo_class = MY_PPO) ## 配置使用自己的算法
# config = config.environment("CartPole-v1")
# config = config.rollouts(num_rollout_workers=2)
# config = config.framework(framework="torch") 
# algo = config.build()
# result = algo.train()
# print(pretty_print(result))

## method-2
config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 
algo = MY_PPO(config=config,)  ## 在这里配置使用自己的policy
result = algo.train()
print(pretty_print(result))

五、后记:

        如何既要修改网络结构,又要修改loss函数, 可以结合 上一篇博客 和本博客共同实现。

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

标题基于Python的汽车之家网站舆情分析系统研究AI更换标题第1章引言阐述汽车之家网站舆情分析的研究背景、意义、国内外研究现状、论文方法及创新点。1.1研究背景与意义说明汽车之家网站舆情分析对汽车行业及消费者的重要性。1.2国内外研究现状概述国内外在汽车舆情分析领域的研究进展与成果。1.3论文方法及创新点介绍本文采用的研究方法及相较于前人的创新之处。第2章相关理论总结和评述舆情分析、Python编程及网络爬虫相关理论。2.1舆情分析理论阐述舆情分析的基本概念、流程及关键技术。2.2Python编程基础介绍Python语言特点及其在数据分析中的应用。2.3网络爬虫技术说明网络爬虫的原理及在舆情数据收集中的应用。第3章系统设计详细描述基于Python的汽车之家网站舆情分析系统的设计方案。3.1系统架构设计给出系统的整体架构,包括数据收集、处理、分析及展示模块。3.2数据收集模块设计介绍如何利用网络爬虫技术收集汽车之家网站的舆情数据。3.3数据处理与分析模块设计阐述数据处理流程及舆情分析算法的选择与实现。第4章系统实现与测试介绍系统的实现过程及测试方法,确保系统稳定可靠。4.1系统实现环境列出系统实现所需的软件、硬件环境及开发工具。4.2系统实现过程详细描述系统各模块的实现步骤及代码实现细节。4.3系统测试方法介绍系统测试的方法、测试用例及测试结果分析。第5章研究结果与分析呈现系统运行结果,分析舆情数据,提出见解。5.1舆情数据可视化展示通过图表等形式展示舆情数据的分布、趋势等特征。5.2舆情分析结果解读对舆情分析结果进行解读,提出对汽车行业的见解。5.3对比方法分析将本系统与其他舆情分析系统进行对比,分析优劣。第6章结论与展望总结研究成果,提出未来研究方向。6.1研究结论概括本文的主要研究成果及对汽车之家网站舆情分析的贡献。6.2展望指出系统存在的不足及未来改进方向,展望舆情
【磁场】扩展卡尔曼滤波器用于利用高斯过程回归进行磁场SLAM研究(Matlab代码实现)内容概要:本文介绍了利用扩展卡尔曼滤波器(EKF)结合高斯过程回归(GPR)进行磁场辅助的SLAM(同步定位与地图构建)研究,并提供了完整的Matlab代码实现。该方法通过高斯过程回归对磁场空间进行建模,有效捕捉磁场分布的非线性特征,同时利用扩展卡尔曼滤波器融合传感器数据,实现移动机器人在复杂环境中的精确定位与地图构建。研究重点在于提升室内等无GPS环境下定位系统的精度与鲁棒性,尤其适用于磁场特征明显的场景。文中详细阐述了算法原理、数学模型构建、状态估计流程及仿真实验设计。; 适合人群:具备一定Matlab编程基础,熟悉机器人感知、导航或状态估计相关理论的研究生、科研人员及从事SLAM算法开发的工程师。; 使用场景及目标:①应用于室内机器人、AGV等在缺乏GPS信号环境下的高精度定位与地图构建;②为磁场SLAM系统的设计与优化提供算法参考和技术验证平台;③帮助研究人员深入理解EKF与GPR在非线性系统中的融合机制及实际应用方法。; 阅读建议:建议读者结合Matlab代码逐模块分析算法实现细节,重点关注高斯过程回归的训练与预测过程以及EKF的状态更新逻辑,可通过替换实际磁场数据进行实验验证,进一步拓展至多源传感器融合场景。
<think> 我们正在讨论一个包含policy_net成员的DQNAgent类,需要调用policy_net的forward方法。 根据引用[5](ray.rllib-入门实践-12-2)的提示,在RLlib自定义模型时,我们可以通过config.model["custom_model_config"]传递参数,但这里我们更关心如何调用forward方法。 通常,在PyTorch中,一个神经网络模块(nn.Module)的forward方法可以通过直接调用模块对象来触发,因为__call__方法会调用forward。例如:output = policy_net(input_tensor) 但是,在RLlib的框架中,我们可能需要在Policy类的函数中使用这个模型。引用[1]提到我们可以自定义TorchPolicy类,并替换其中的函数。例如,在计算动作的函数中,我们会调用模型的forward方法。 因此,在DQNAgent类中,如果我们有一个policy_net成员(它是一个torch.nn.Module的子类实例),那么调用forward方法的方式就是直接传入输入张量,如: output = self.policy_net(observation_tensor) 但需要注意的是,在RLlib中,我们通常不会直接调用policy_net,而是通过Policy类的compute_actions方法。然而,如果我们确实需要直接调用(例如在自定义的训练步骤中),那么就像上面那样调用。 另外,引用[3]提到训练终止条件,但这与当前问题关系不大。 所以,总结回答:在DQNAgent类中,调用policy_net的forward方法,可以直接将输入张量传递给policy_net对象,因为PyTorch模块的__call__会调用forward。 示例代码: # 假设我们有一个DQNAgent类 class DQNAgent: def __init__(self, ...): self.policy_net = ... # 初始化一个神经网络 def get_action(self, observation): # 将observation转换为tensor,然后通过policy_net前向传播 obs_tensor = torch.tensor(observation, dtype=torch.float32) # 调用forward方法:直接使用policy_net(obs_tensor) q_values = self.policy_net(obs_tensor) action = q_values.argmax().item() return action 注意:在实际使用中,我们可能需要考虑batch维度(例如,obs_tensor可能是[1, state_dim])以及是否使用梯度(如评估时不需要梯度,可以用with torch.no_grad())。 根据引用[1]和[5],在RLlib自定义Policy中,我们通常会在覆盖的函数(如`forward`)中调用模型的forward方法,但具体到用户的问题,是在一个包含policy_net成员的DQNAgent类中调用,所以直接调用即可。 因此,回答如下: </think> 在包含 `policy_net` 成员的 `DQNAgent` 类中调用其 `forward` 方法的正确方式如下: ### 调用方式 1. **直接调用**(推荐) 通过实例化对象直接调用神经网络模块,PyTorch 会自动触发 `forward` 方法: ```python # 输入观测状态(需转换为张量) observation_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) # 添加 batch 维度 # 调用 forward 方法 q_values = agent.policy_net(observation_tensor) ``` 2. **显式调用** 直接调用 `forward` 方法(需手动处理梯度): ```python with torch.no_grad(): # 推理时通常不需要梯度 q_values = agent.policy_net.forward(observation_tensor) ``` ### 关键说明 - **输入格式**:`forward` 要求输入为 `torch.Tensor` 且包含 batch 维度(即使 batch_size=1)[^1][^5]。 - **前向传播流程**: ```mermaid graph LR A[原始观测 obs] --> B[转换为 Tensor] B --> C[添加 batch 维度] C --> D{调用 policy_net} D --> E[输出 Q 值张量] ``` - **RLlib 集成**:在自定义 Policy 类中,通常通过覆盖 `compute_actions` 方法间接调用 `forward`[^1][^5]: ```python class CustomPolicy(TorchPolicy): def compute_actions(self, obs_batch, ...): logits, _ = self.model({"obs": obs_batch}) # 内部调用 forward return torch.argmax(logits, dim=1) ``` ### 示例代码 ```python import torch class DQNAgent: def __init__(self, state_dim, action_dim): self.policy_net = torch.nn.Linear(state_dim, action_dim) # 示例网络 def predict(self, obs): obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) return self.policy_net(obs_tensor) # 自动调用 forward # 使用示例 agent = DQNAgent(state_dim=4, action_dim=2) observation = [0.5, -0.2, 0.1, 0.7] q_values = agent.predict(observation) # 输出: tensor([[0.12, -0.05]]) ``` ### 注意事项 - **梯度管理**:训练时需保留梯度,评估时应使用 `torch.no_grad()`[^5]。 - **设备兼容**:若模型在 GPU 上,需将输入数据移至相同设备: ```python obs_tensor = obs_tensor.to(next(agent.policy_net.parameters()).device) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值