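It is recommended to first read the blog post "ray.rllib-入门实践-12-1:在自定义policy中注册使用自定义model" (registering and using a custom model inside a custom policy). This post differs from it in one respect: the custom model gets additional parameters of its own, and these parameters are passed in through config.model["custom_model_config"].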
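The mechanism can be sketched without RLlib at all. The code below is a simplified, framework-free imitation of what RLlib's `ModelCatalog.get_model_v2` does for a custom model: it takes the dict stored under `config.model["custom_model_config"]` and unpacks it as keyword arguments into the model constructor. `DemoModel` and `build_custom_model` are hypothetical names used only for illustration; they are not RLlib APIs.

```python
from typing import Any, Dict


class DemoModel:
    """Hypothetical stand-in for a custom TorchModelV2 subclass."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 *, custom_arg1, custom_arg2):
        # custom_arg1 / custom_arg2 are keyword-only: they can only arrive as kwargs.
        self.num_outputs = num_outputs
        self.name = name
        self.custom_arg1 = custom_arg1
        self.custom_arg2 = custom_arg2


def build_custom_model(model_cls, obs_space, action_space, num_outputs,
                       model_config: Dict[str, Any], name: str):
    # Simplified imitation: every key of custom_model_config becomes a keyword argument.
    extra_kwargs = dict(model_config.get("custom_model_config", {}))
    return model_cls(obs_space, action_space, num_outputs, model_config, name, **extra_kwargs)


model_config = {
    "custom_model": "custom_torch_model",
    "custom_model_config": {"custom_arg1": 64, "custom_arg2": "tanh"},
}
m = build_custom_model(DemoModel, None, None, 2, model_config, "demo")
print(m.custom_arg1, m.custom_arg2)  # prints: 64 tanh
```

Because `custom_arg1` and `custom_arg2` are declared after `*`, leaving one of them out of `custom_model_config` fails immediately with a `TypeError` instead of silently falling back to a default.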
Environment:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
Example code:
import numpy as np
import torch
import torch.nn as nn
import gymnasium as gym
from gymnasium import spaces
from typing import Dict, List, Type, Union
import ray
from ray.rllib.models import ModelCatalog  # ModelCatalog: registers custom models and provides the env's preprocessors and action distributions
from ray.rllib.algorithms.ppo import PPO, PPOConfig, PPOTorchPolicy
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType, ModelConfigDict
from ray.tune.logger import pretty_print
## 1. Custom model
class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,  # the model settings given via PPOConfig.training(model=...), i.e. config.model
                 name: str,
                 *, custom_arg1, custom_arg2):  # keyword-only; filled from config.model["custom_model_config"]
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)  # a TorchModelV2 that is also an nn.Module must initialize nn.Module too
