ray.rllib Getting Started in Practice - 8: Model Inference and Evaluation


For how to train, save, and load a model, see the earlier posts in this series:

        ray.rllib Getting Started in Practice - 5: Training an Algorithm

        ray.rllib Getting Started in Practice - 6: Saving a Model

        ray.rllib Getting Started in Practice - 7: Loading a Trained Model

This post simply trains, saves, and loads a model with the methods recommended there, and then introduces two ways to evaluate the model.

Environment:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18
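If you want to confirm that your environment matches the versions above, a quick sanity check (just a sketch; the expected version strings in the comments are simply the ones listed above) is:

import sys
import ray
import torch
import numpy as np

## Print the installed versions and compare against the list above.
print("python:", sys.version.split()[0])   ## expected 3.9.18
print("ray   :", ray.__version__)          ## expected 2.10.0
print("torch :", torch.__version__)        ## expected 2.5.1
print("numpy :", np.__version__)           ## expected 1.23.0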

1. Training and saving the model

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print


## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path,exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Needed for algo.evaluate() later; without evaluation workers that call fails.
config.output = storage_path  ## Where to store the output (sample) files generated during training

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train() 
    print(f"training iteration {i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir) ## Save a checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")

2. Model evaluation

Method 1: statistical evaluation over multiple episodes

## Method 1: algo.evaluate()
## Prerequisite: the evaluation options must have been set on the config at training time, otherwise this call fails.
## It runs several episodes, aggregates the results, and returns the statistics.
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Evaluate the model
evaluation_result = algo.evaluate() ## Fails unless the evaluation options were configured at training time
print(pretty_print(evaluation_result))
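algo.evaluate() returns a nested dict whose statistics usually sit under an "evaluation" key. A minimal sketch of pulling out two common metrics (the exact key layout is assumed from RLlib's standard result dict and can vary between Ray versions):

## Extract summary statistics from the evaluation result.
eval_metrics = evaluation_result.get("evaluation", evaluation_result)
print("mean episode reward:", eval_metrics.get("episode_reward_mean"))
print("mean episode length:", eval_metrics.get("episode_len_mean"))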

Method 2: single-episode rollout

import gymnasium as gym 
## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Run inference
step = 0
episode_reward = 0
terminated = truncated = False 
obs,info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward},\
          action = {action}, obs = {obs}, \
        episode_reward = {episode_reward}")
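Note that compute_single_action samples from the policy's action distribution by default, so two rollouts of the same checkpoint can differ. For a deterministic (greedy) evaluation you can disable exploration; a short sketch of the same loop with explore=False:

## Deterministic variant of the rollout above: disable exploration when computing actions.
obs, info = env.reset()
terminated = truncated = False
episode_reward = 0
while not terminated and not truncated:
    action = algo.compute_single_action(obs, explore=False)  ## greedy action, no sampling
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
print(f"deterministic episode_reward = {episode_reward}")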

3. Full code

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print
import gymnasium as gym 


## Configure the algorithm
storage_path = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs"
os.makedirs(storage_path,exist_ok=True)

config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config = config.evaluation(evaluation_num_workers=1)  ## Needed for algo.evaluate() later; without evaluation workers that call fails.
config.output = storage_path  ## Where to store the output (sample) files generated during training

## Build the algorithm
algo = config.build()

## Train the algorithm
for i in range(3):
    result = algo.train() 
    print(f"training iteration {i}")

## Save the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
os.makedirs(checkpoint_dir,exist_ok=True)
algo.save_checkpoint(checkpoint_dir) ## Save a checkpoint to the given directory
print(f"saved checkpoint to {checkpoint_dir}")


#################  evaluate  #############################
# ## Method 1: algo.evaluate(). Runs several episodes, aggregates the results, and returns the statistics.
# ## Load the model
# checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
# algo = PPO.from_checkpoint(checkpoint_dir)
# print(f"Loaded checkpoint from {checkpoint_dir}")
# ## Evaluate the model
# evaluation_result = algo.evaluate() ## Fails unless the evaluation options were configured at training time
# print(pretty_print(evaluation_result))

## Method 2: algo.compute_single_action(obs)
## Create the environment
env_name = "CartPole-v1"
env = gym.make(env_name)
## Load the model
checkpoint_dir = "F:/codes/RLlib_study/ray_rllib_tutorials/outputs/checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"Loaded checkpoint from {checkpoint_dir}")
## Run inference
step = 0
episode_reward = 0
terminated = truncated = False 
obs,info = env.reset()
while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    step += 1
    print(f"step = {step}, reward = {reward},\
          action = {action}, obs = {obs}, \
        episode_reward = {episode_reward}")
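When you are done, you can release the worker actors held by the algorithm and shut down Ray. A small cleanup sketch:

## Optional cleanup: stop the algorithm's workers and shut down Ray.
import ray
algo.stop()
ray.shutdown()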
