from ray.rllib.algorithms.ppo import PPOConfig
config = PPOConfig()
config = config.environment(env="CartPole-v1", env_config={})
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0, num_cpus_per_worker=1, num_gpus_per_worker=0)
algo = config.build()
################## Get the model's network weights ############################
"""
说明:
rllib中, actor 网络和 critic 网络被统一整合到一个 继承自 ModelV2 或 TorchModelV2 的类下,
所以 rllib 中的模型指的是这个继承自 ModelV2 或 TorchModelV2 的类,
里面即包含了所有用到的网络。
"""
## (Recommended) Method 1: get the network weights of the default local model
weights = algo.get_policy().get_weights()
print(f"weights = {weights}")
print(f"weights_keys = {weights.keys()}")
## Method 2: get the network weights of the default local model; equivalent to Method 1.
# weights = algo.workers.local_worker().policy_map["default_policy"].get_weights()
# print(f"weights = {weights}")
# print(f"weights_keys = {weights.keys()}")
## Method 3: get the network weights of the model replica on each worker, including remote replicas
# workers_weights = algo.workers.foreach_worker(lambda worker: worker.get_policy().get_weights())
# print(f"workers_weights = {workers_weights}")
# print(f"worker_0_weights = {workers_weights[0]}")
# print(f"worker_0_weights.keys = {workers_weights[0].keys()}")
# print(f"worker_num = {len(workers_weights)}")
## Method 4: get the network weights of the model replica on each worker, including remote replicas
# workers_weights = algo.workers.foreach_worker_with_id(
# lambda _id, worker: worker.get_policy().get_weights()
# )
# print(f"workers_weights = {workers_weights}")
# print(f"worker_0_weights = {workers_weights[0]}")
# print(f"worker_0_weights.keys = {workers_weights[0].keys()}")
# print(f"worker_num = {len(workers_weights)}")
## Method 5: get the state of each worker
# workers_states = algo.workers.foreach_worker_with_id(
# lambda _id, worker: worker.get_state()
# )
# print(f"workers_states = {workers_states}")
# print(f"workers_states[0].keys = {workers_states[0].keys()}")