[Reinforcement Learning] Training CartPole-v1 with PPO2 (discrete-action case) in TensorFlow 2.x

This example shows how to build and train an Actor-Critic model in Python with TensorFlow to solve OpenAI Gym's CartPole-v1 environment. The actor network selects actions, the critic network estimates the state-value function, and the policy is updated with the PPO2 clipped surrogate objective.
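
For reference, the actor loss in the code below is PPO2's clipped surrogate objective, stated here in its standard form (with advantage estimate \hat{A}_t and clip range \epsilon, which the script sets to 0.2):

L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}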


Algorithm flow

(Figure: PPO2 algorithm flow diagram)
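
In outline, the training loop implemented in the code below works as follows:

1. Roll out one full episode with the old (behavior) policy, storing states, actions, rewards, and next states.
2. Compute discounted returns backwards through the episode, bootstrapping the value of the terminal next state to 0.
3. Form advantages as returns minus the critic's value estimates, then normalize them.
4. For K_epoch epochs, update the actor with the clipped surrogate loss and the critic with an MSE loss on the returns.
5. Copy the new policy's weights into the old policy and move on to the next episode.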

Code

import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np

import gym
import copy

def build_actor_network(state_dim, action_dim):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dense(units=action_dim, activation='softmax')
    ])
    model.build(input_shape=(None, state_dim))
    return model

def build_critic_network(state_dim):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dense(units=1, activation='linear')
    ])
    model.build(input_shape=(None, state_dim))
    return model

class Actor(object):
    def __init__(self, state_dim, action_dim, lr):
        self.action_dim = action_dim
        self.old_policy = build_actor_network(state_dim, action_dim)
        self.new_policy = build_actor_network(state_dim, action_dim)
        self.update_policy()

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    def choice_action(self, state):
        policy = tf.stop_gradient(self.old_policy(
            np.array([state])
        )).numpy()[0]
        return np.random.choice(
            self.action_dim,
            p=policy
        ), policy

    def update_policy(self):
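        # Sync the old (behavior) policy with the latest trained policy weights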
        self.old_policy.set_weights(
            self.new_policy.get_weights()
        )

    def learn(self, batch_state, batch_action, advantage, epsilon=0.2):
        advantage = np.reshape(advantage, newshape=(-1))
        # Pair each row index with its action so tf.gather_nd can pick the
        # probability of the action taken; cast to int32 to match tf.range.
        batch_action = tf.cast(batch_action, dtype=tf.int32)
        batch_action = tf.stack([tf.range(tf.shape(batch_action)[0], dtype=tf.int32), batch_action], axis=1)
        old_policy = self.old_policy(batch_state)
        with tf.GradientTape() as tape:
            new_policy = self.new_policy(batch_state)

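            # Probabilities of the actions actually taken, under the new and old policies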
            pi_prob = tf.gather_nd(params=new_policy, indices=batch_action)
            oldpi_prob = tf.gather_nd(params=old_policy, indices=batch_action)
            ratio = pi_prob / (oldpi_prob + 1e-6)
            surr1 = ratio * advantage
            surr2 = tf.clip_by_value(ratio, clip_value_min=1.0 - epsilon, clip_value_max=1.0 + epsilon) * advantage
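            # PPO2 clipped surrogate objective, negated because the optimizer minimizes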
            loss = - tf.reduce_mean(tf.minimum(surr1, surr2))
        grad = tape.gradient(loss, self.new_policy.trainable_variables)
        self.optimizer.apply_gradients(zip(grad, self.new_policy.trainable_variables))

    def save_weights(self, path):
        self.old_policy.save_weights(path)

    def load_weights(self, path):
        self.old_policy.load_weights(path)
        self.new_policy.load_weights(path)

class Critic(object):
    def __init__(self, state_dim, lr):
        self.value = build_critic_network(state_dim)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    def get_advantage(self, state, reward):
        return reward - self.value.predict(state, verbose=0)

    def get_value(self, state):
        return self.value.predict(
            state,
            verbose=0
        )

    def learn(self, batch_state, batch_reward):
        with tf.GradientTape() as tape:
            value_predict = self.value(batch_state)
            # Fit the value network to the discounted returns with an MSE loss.
            loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(batch_reward, value_predict))
        grad = tape.gradient(loss, self.value.trainable_variables)
        self.optimizer.apply_gradients(zip(grad, self.value.trainable_variables))

    def save_weights(self, path):
        self.value.save_weights(path)

    def load_weights(self, path):
        self.value.load_weights(path)

if __name__ == '__main__':
    episodes = 200
    env = gym.make("CartPole-v1")
    A_learning_rate = 1e-3
    C_learning_rate = 1e-3
    actor = Actor(4, 2, A_learning_rate)
    critic = Critic(4, C_learning_rate)
    gamma = 0.9
    lam = 0.98
    assert 0.0 <= lam <= 1.0, "lam must be in [0, 1]"
    K_epoch = 10
    assert K_epoch > 1, "K_epoch must be greater than 1, otherwise the importance-sampling ratio is pointless"

    plot_score = []
    for e in range(episodes):
        state = env.reset()  # classic Gym API (gym < 0.26): reset() returns only the observation
        S, A, R, nS = [], [], [], []
        score = 0.0
        while True:
            action, policy = actor.choice_action(state)
            next_state, reward, done, _ = env.step(action)
            score += reward
            S.append(state)
            A.append(action)
            R.append(reward)
            nS.append(next_state)
            state = copy.deepcopy(next_state)
            if done:
                # Compute lambda-return style discounted returns backwards through
                # the episode; the value of the terminal next state is set to 0.
                discounted_r = []
                tmp_r = 0.0
                v_nS = critic.get_value(np.array(nS, dtype=np.float32))
                v_nS[-1] = 0
                for r, vs in zip(R[::-1], v_nS[::-1]):
                    tmp_r = r + gamma * (lam * tmp_r + (1 - lam) * vs[0])
                    discounted_r.append(np.array([tmp_r]))
                discounted_r.reverse()

                # np.float was removed in NumPy 1.24; np.float32 also matches TF's default dtype.
                bs = np.array(S, dtype=np.float32)
                ba = np.array(A)
                br = np.array(discounted_r, dtype=np.float32)

                advantage = critic.get_advantage(bs, br)
                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-6)
                for k in range(K_epoch):
                    actor.learn(bs, ba, advantage)
                    critic.learn(bs, br)
                actor.update_policy()
                print("episode: {}/{}, score: {}".format(e + 1, episodes, score))
                break
        plot_score.append(score)
    plt.plot(plot_score)
    plt.show()
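
The interaction loop above uses the classic Gym API, where env.reset() returns only the observation and env.step() returns four values. On Gym >= 0.26 or Gymnasium, reset() returns (obs, info) and step() returns (obs, reward, terminated, truncated, info), so the rollout needs a small adaptation. A minimal sketch of the adapted interaction loop, with a random action standing in for actor.choice_action:

import gym

env = gym.make("CartPole-v1")
state, _ = env.reset()                      # reset() now returns (observation, info)
done = False
while not done:
    action = env.action_space.sample()      # stand-in for actor.choice_action(state)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated          # the single done flag is split into two
    state = next_state
env.close()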

Reward over 200 training episodes

