【强化学习】MuZero 训练CartPole-v1

该博客详细介绍了如何在TensorFlow2.x环境下复现MuZero算法,应用于CartPole-v1环境。作者提供了完整的代码实现,包括MCTS的Atari和棋类游戏版本,以及模型的训练过程。代码经过优化,支持多线程训练,展示了训练过程中的奖励曲线。此外,还包含了训练和模型更新的损失函数计算。

【深度强化学习】tensorflow2.x复现 muzero 训练CartPole-v1

参考资料
[1]ColinFred. 蒙特卡洛树搜索(MCTS)代码详解【python】. 2019-03-23 23:37:09.
[2]饼干Japson 深度强化学习实验室.【论文深度研读报告】MuZero算法过程详解.2021-01-19.
[3]Tangarf. Muzero算法研读报告. 2020-08-31 11:40:20 .
[4]带带弟弟好吗. AlphaGo版本三——MuZero. 2020-08-30.
[5]Google原论文:Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6]参考GitHub代码1.
[7]参考GitHub代码2.

请添加图片描述
关于代码

  • 这里先道个歉,因为考研的缘故之前复现的那个 muzero 没有完成,并且代码也因为换电脑不小心搞丢了(GitHub里面的代码全是报错,写的太乱也懒得改了)。。。。。。这里重新写一遍,比之前的那个代码更加简单易读。
  • 蒙特卡洛树的代码 (MCTS.py) 部分这次将用于 Atari 游戏的蒙特卡洛树搜索和用于 棋类游戏(chess)的蒙特卡洛树搜索分开写,下面的代码里完成了用于 Atari 游戏的蒙特卡洛树搜索,至于如果有需求要写用于棋类游戏的 muzero 可以自行完成(可以根据 Atari 的 MCTS 更改,这次的代码非常容易读,照着 Atari 的 MCTS 去写非常简单)
  • 本次的代码是完整的muzero代码(至少是对gym里面的游戏),可以根据需求自行更改。

更新

  • 补全了棋类的MCTS,加入了探索噪声。
  • 更新了多线程训练方法(对训练函数trainer.py略有修改)

cartpole-v1_muzero.py

from cartpole_v1_model import linner_model
from trainer import ReplayBuffer, Trainer
from MCTS import MCTS_Atari
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import copy
import gym

# 配置GPU内存
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

ENV_NAME = "CartPole-v1"
OBSERVATION_SIZE = 4
ACTION_SIZE = 2

DISCOUNT = 0.9
NUM_SIMULATIONS = 50
UNROLL_STEPS = 9
MEMORY_SIZE = int(1e6)
SIMPLE_SIZE = 1024
EPISODES = 1000

class muzero:
    def __init__(self, observation_size, action_size):
        self.model = linner_model(observation_size, action_size)
        self.MCTS = MCTS_Atari

    def choice_action(self, observation, T=1.0):
        MCTS = self.MCTS(self.model, observation)
        visit_count, MCTS_value = MCTS.simulations(NUM_SIMULATIONS, DISCOUNT)
        visit_counts = list(visit_count.values())
        prob = np.array(visit_counts) ** (1 / T) / np.sum(np.array(visit_counts) ** (1 / T))
        return np.random.choice(len(prob), p=prob), prob, MCTS_value

    def plot_score(self, scores):
        plt.plot(scores)
        plt.show()

if __name__ == '__main__':
    env = gym.make(ENV_NAME)
    agent = muzero(OBSERVATION_SIZE, ACTION_SIZE)

    trainer = Trainer(discount=DISCOUNT)
    replay_buffer = ReplayBuffer(MEMORY_SIZE, UNROLL_STEPS)

    scores = []
    for e in range(EPISODES):
        state = env.reset()
        action_next, policy_next, _ = agent.choice_action(state)
        rewards = 0
        while True:
            env.render()
            action, policy = action_next, policy_next
            next_state, reward, done, _ = env.step(action)
            action_next, policy_next, value_next = agent.choice_action(next_state)
			reward = -100 if done else reward
            done = 1 if done else 0

            rewards += reward
            action_onehot = np.array([1 if i == action else 0 for i in range(ACTION_SIZE)])
            replay_buffer.save_memory(state, policy, action_onehot, reward, value_next, next_state, done)
            state = copy.deepcopy(next_state)

            if done: break
        
		rewards += 100
        scores.append(rewards)
        policy_loss, value_loss, reward_loss = trainer.update_weight(agent.model, replay_buffer, SIMPLE_SIZE)

        print("episode: {}/{}, policy_loss: {}, value_loss: {}, reward_loss: {}, score: {}".format(
            e + 1, EPISODES, policy_loss, value_loss, reward_loss, rewards))

    agent.plot_score(scores)

cartpole_v1_model.py

from tensorflow.keras import layers, Input, Model
import numpy as np

class representation:
    def __init__(self, observation_size):
        observation = Input(shape=(observation_size))
        x = layers.Flatten()(observation)
        x = layers.Dense(units=128, activation='relu')(x)
        x = layers.Dense(units=128, activation='relu')(x)
        hidden_state = layers.Dense(units=observation_size)(x)
        self.model = Model(inputs=observation, outputs=hidden_state)
        # self.model.summary()
        self.trainable_variables = self.model.trainable_variables

    def predict(self, observation):
        observation = np.array([observation])
        hidden_state = np.array(self.model(observation)[0])
        return hidden_state

class dynamics:
    def __init__(self, observation_size, action_size):
        self.action_size = action_size
        hidden_state = Input(shape=(observation_size))
        action = Input(shape=(action_size))
        x = layers.concatenate([hidden_state, action])
        x = layers.Dense(units=128, activation='relu')(x)
        x = layers.Dense(units=128, activation='relu')(x)
        next_hidden_state = layers.Dense(units=observation_size)(x)
        reward = layers.Dense(units=1)(x)
        self.model = Model(inputs=[hidden_state, action], outputs=[next_hidden_state, reward])
        # self.model.summary()
        self.trainable_variables = self.model.trainable_variables

    def predict(self, hidden_state, action):
        hidden_state = np.array([hidden_state])
        action = np.array([[1 if i == action else 0 for i in range(self.action_size)]])
        next_hidden_state, reward = self.model([hidden_state, action])
        next_hidden_state = np.array(next_hidden_state[0])
        reward = np.array(reward[0][0])
        return next_hidden_state, reward

class prediction:
    def 
### 关于 CartPole-v1 和 PPQ 的实现与集成 CartPole-v1 是一个经典的强化学习环境,通常用于测试和验证强化学习算法的性能。PPQ 是一个高效的神经网络量化工具,主要用于深度学习模型的优化和部署。将 PPQ 与 CartPole-v1 结合,可以通过量化技术优化强化学习模型的推理速度和存储需求。 以下是一个简化的实现示例,展示如何将 PPQ 应用于 CartPole-v1强化学习模型: #### 使用 PPQ 对 CartPole-v1 模型进行量化 ```python import gym import torch import torch.nn as nn from stable_baselines3 import PPO import ppq from ppq import QuantizationSettingFactory, GraphExporter # 加载 CartPole-v1 环境 env = gym.make('CartPole-v1') # 创建并训练一个简单的强化学习模型(PPO) model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000) # 导出训练好的模型为 ONNX 格式 dummy_input = torch.randn(1, 4) # 输入形状为 (batch_size, observation_space) torch.onnx.export(model.policy, dummy_input, "cartpole_model.onnx") # 加载 ONNX 模型到 PPQ onnx_model = ppq.load_onnx_model('cartpole_model.onnx') # 创建量化配置 quant_setting = QuantizationSettingFactory.default_setting() # 执行量化 executor = ppq.PPQExecutor(graph=onnx_model, setting=quant_setting) quantized_model = executor.quantize() # 导出量化后的模型 exporter = GraphExporter() exporter.export(file_path='quantized_cartpole_model.onnx', graph=quantized_model, config=quant_setting) ``` 上述代码展示了如何使用 PPQ 对 CartPole-v1强化学习模型进行量化[^1]。具体步骤包括加载环境、训练模型、导出为 ONNX 格式、加载到 PPQ 并执行量化,最后导出量化后的模型。 #### 注意事项 1. 在实际应用中,可能需要对模型进行校准以确保量化后的精度损失在可接受范围内。 2. 强化学习模型的结构可能较为复杂,因此需要根据具体任务调整量化设置。 3. PPQ 支持多种硬件平台和推理框架,可以根据目标设备选择合适的后端[^2]。 ###
评论 17
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值