环境配置:
torch==2.5.1
ray==2.10.0
ray[rllib]==2.10.0
ray[tune]==2.10.0
ray[serve]==2.10.0
numpy==1.23.0
python==3.9.18
代码:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
## 配置算法
algo = (
PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="CartPole-v1")
.build()
)
## 训练模型. 每个 iter 里重复执行多次 episode. 直到满足条件, 比如新增采样量达到一定体量。
for i in range(3):
result = algo.train()
print(pretty_print(result))
## 保存模型
checkpoint_dir =

最低0.47元/天 解锁文章
1291

被折叠的 条评论
为什么被折叠?



