在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。
2. 向ray注册自定义的模型
3. 在config中配置使用自定义的模型
环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
一、 定义自己的模型
需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。
def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...
def forward(self, input_dict, state, seq_lens): ...
def value_function(self): ...
示例如下:
## 1. 定义自己的模型
import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym
from gymnasium import spaces
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict
class My_Model(TorchModelV2, nn.Module):
def __init__(self, obs_space:gym.spaces.Space,
action_space:gym.spaces.Space,
num_outputs:int,
model_config:ModelConfigDict, ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数
name:str
,*, custom_arg1, custom_arg2):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)
nn.Module.__init__(self)
## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值
print(f"=========================== custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}

最低0.47元/天 解锁文章
1030

被折叠的 条评论
为什么被折叠?



