[Reinforcement Learning] Using PARL — PARL Quick Start: Solving the CartPole Problem

Official quick-start tutorial: https://parl.readthedocs.io/zh-cn/latest/tutorial/getting_started.html

CartPole is also known as the inverted pendulum: a pole stands on a cart and tends to fall over under gravity.
To keep the pole from falling, we move the cart left or right so that the pole stays upright, as shown in the figure below.

At every time step, the model's input is a 4-dimensional vector describing the current state of the cart and the pole, and the model's output is a signal that pushes the cart to the left or to the right.

As long as the pole has not fallen, the environment gives a reward of 1 at every time step;
once the pole falls, no further reward is given and the episode ends.
(figure: the CartPole environment)
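A minimal sketch of this interaction loop, using the raw gym 0.26 API pinned in the environment below and a random policy, just to make the observation/action/reward cycle concrete (the training script later wraps the env with PARL's CompatWrapper instead):

import gym

env = gym.make('CartPole-v1')
obs, _ = env.reset()  # 4-dim state: cart position, cart velocity, pole angle, pole angular velocity
episode_reward = 0.0
while True:
    action = env.action_space.sample()  # 0 = push the cart left, 1 = push the cart right
    obs, reward, terminated, truncated, _ = env.step(action)
    episode_reward += reward  # +1 for every step the pole stays up
    if terminated or truncated:  # the pole fell over (or the 500-step limit was hit)
        break
print('random policy survived', episode_reward, 'steps')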

Environment

Windows 11, AMD Ryzen 3700 CPU (x86_64)

Python 3.8.10

Pinned dependency versions:

absl-py==2.2.2
anyio==4.5.2
astor==0.8.1
blinker==1.8.2
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==1.6.0
colorama==0.4.6
decorator==4.4.2
exceptiongroup==1.2.2
filelock==3.16.1
Flask==3.0.3
Flask-Cors==5.0.0
fsspec==2025.3.0
gast==0.3.3
google-auth==2.39.0
google-auth-oauthlib==0.4.6
grpcio==1.37.0
gym==0.26.2
gym-notices==0.0.8
h11==0.14.0
httpcore==1.0.8
httpx==0.28.1
idna==3.10
importlib_metadata==8.5.0
itsdangerous==2.2.0
Jinja2==3.1.6
Markdown==3.7
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.1
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
paddle-bfloat==0.1.7
paddlepaddle==2.5.0
parl==2.2.1
pillow==10.4.0
protobuf==3.20.0
psutil==7.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
pygame==2.6.1
pynvml==11.5.3
pyzmq==18.1.1
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
scipy==1.10.0
six==1.17.0
sniffio==1.3.1
sympy==1.13.3
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5
termcolor==2.4.0
typing_extensions==4.13.2
urllib3==2.2.3
Werkzeug==3.0.6
zipp==3.20.2
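Assuming a plain pip setup, the core packages can be installed directly with the versions pinned above (pygame is only needed for the rendered evaluation later on); the rest of the list can be frozen into a requirements.txt if desired:

pip install paddlepaddle==2.5.0 parl==2.2.1 gym==0.26.2 pygame==2.6.1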

Code

Simply copy the files from PARL's QuickStart example:
https://github.com/PaddlePaddle/PARL/tree/develop/examples/QuickStart
It consists of three files.
(figure: the three files of the QuickStart example)
Below is the full code of train.py. I made small changes to the original, which is why it is reproduced here;
the other two files, cartpole_agent.py and cartpole_model.py, need no changes, so they are not shown.
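For orientation, the policy network defined in cartpole_model.py is a small two-layer MLP whose structure also shows up in the model printout in the training log below. A rough sketch of an equivalent model, assuming PARL's parl.Model base class and Paddle's nn layers (the actual file in the QuickStart example may differ in details):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl


class CartpoleModel(parl.Model):
    """Policy network: 4-dim observation -> 20 hidden units -> 2 action probabilities."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hid_size = 20  # matches the fc1 layer in the printed model below
        self.fc1 = nn.Linear(obs_dim, hid_size)
        self.fc2 = nn.Linear(hid_size, act_dim)

    def forward(self, obs):
        out = paddle.tanh(self.fc1(obs))
        prob = F.softmax(self.fc2(out), axis=-1)  # probability of pushing the cart left / right
        return prob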

#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import gym
import numpy as np
import parl
from parl.utils import logger
from parl.env import CompatWrapper, is_gym_version_ge
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
import argparse

LEARNING_RATE = 1e-3


# train an episode
def run_train_episode(agent, env):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


# evaluate 5 episodes
def run_evaluate_episodes(agent, eval_episodes=5, render=False):
    # Compatible for different versions of gym
    if is_gym_version_ge("0.26.0") and render:  # if gym version >= 0.26.0
        env = gym.make('CartPole-v1', render_mode="human")
    else:
        env = gym.make('CartPole-v1')
    env = CompatWrapper(env)

    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)
            obs, reward, isOver, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if isOver:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # reward to go: G_i = r_i + γ·G_{i+1}
        reward_list[i] += gamma * reward_list[i + 1]  # Gt
    return np.array(reward_list)


def main():
    env = gym.make('CartPole-v1')
    # Compatible for different versions of gym
    env = CompatWrapper(env)
    # env = env.unwrapped # Cancel the minimum score limit
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
    print(model)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg)

    # load model and evaluate
    # if os.path.exists('./model.ckpt'):
    #     agent.restore('./model.ckpt')
    #     run_evaluate_episodes(agent, 5, render=True)
    #     exit()

    for i in range(args.max_episodes):
        obs_list, action_list, reward_list = run_train_episode(agent, env)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            total_reward = run_evaluate_episodes(agent, eval_episodes=5, render=False)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Environment
    parser.add_argument(
        '--max_episodes',
        type=int,
        default=1000,
        help='stop condition: number of episodes')
    args = parser.parse_args()
    main()
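One detail from the listing worth spelling out: calc_reward_to_go converts the per-step rewards into returns by accumulating from the end of the episode backwards, G_i = r_i + γ·G_{i+1}. A small worked check, reusing the function exactly as defined above:

import numpy as np

def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += gamma * reward_list[i + 1]
    return np.array(reward_list)

print(calc_reward_to_go([1.0, 1.0, 1.0, 1.0]))        # [4. 3. 2. 1.]
print(calc_reward_to_go([1.0, 1.0, 1.0], gamma=0.9))  # [2.71 1.9  1.  ]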


Training output:
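The log below was produced by running the script from the QuickStart directory (max_episodes defaults to 1000 via the argparse flag above):

python train.py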

[04-20 23:30:51 MainThread @utils.py:77] paddlepaddle version: 2.5.0.
[04-20 23:30:51 MainThread @train.py:84] obs_dim 4, act_dim 2
CartpoleModel(
  (fc1): Linear(in_features=4, out_features=20, dtype=float32)
  (fc2): Linear(in_features=20, out_features=2, dtype=float32)
)
[04-20 23:30:51 MainThread @train.py:101] Episode 0, Reward Sum 29.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 10, Reward Sum 28.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 20, Reward Sum 67.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 30, Reward Sum 19.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 40, Reward Sum 42.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 50, Reward Sum 34.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 60, Reward Sum 11.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 70, Reward Sum 26.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 80, Reward Sum 35.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 90, Reward Sum 13.0.
[04-20 23:31:28 MainThread @train.py:111] Test reward: 177.6
[04-20 23:31:28 MainThread @train.py:101] Episode 100, Reward Sum 18.0.
[04-20 23:31:28 MainThread @train.py:101] Episode 110, Reward Sum 26.0.
[04-20 23:31:28 MainThread @train.py:101] Episode 120, Reward Sum 41.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 130, Reward Sum 78.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 140, Reward Sum 38.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 150, Reward Sum 49.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 160, Reward Sum 25.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 170, Reward Sum 23.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 180, Reward Sum 41.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 190, Reward Sum 34.0.
[04-20 23:32:17 MainThread @train.py:111] Test reward: 233.6
[04-20 23:32:17 MainThread @train.py:101] Episode 200, Reward Sum 107.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 210, Reward Sum 52.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 220, Reward Sum 126.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 230, Reward Sum 11.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 240, Reward Sum 41.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 250, Reward Sum 98.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 260, Reward Sum 122.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 270, Reward Sum 47.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 280, Reward Sum 76.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 290, Reward Sum 45.0.
[04-20 23:33:11 MainThread @train.py:111] Test reward: 258.8
[04-20 23:33:11 MainThread @train.py:101] Episode 300, Reward Sum 14.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 310, Reward Sum 85.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 320, Reward Sum 17.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 330, Reward Sum 92.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 340, Reward Sum 110.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 350, Reward Sum 180.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 360, Reward Sum 71.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 370, Reward Sum 135.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 380, Reward Sum 125.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 390, Reward Sum 160.0.
[04-20 23:34:34 MainThread @train.py:111] Test reward: 400.0
[04-20 23:34:34 MainThread @train.py:101] Episode 400, Reward Sum 117.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 410, Reward Sum 75.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 420, Reward Sum 153.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 430, Reward Sum 296.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 440, Reward Sum 45.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 450, Reward Sum 183.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 460, Reward Sum 108.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 470, Reward Sum 12.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 480, Reward Sum 190.0.
[04-20 23:34:36 MainThread @train.py:101] Episode 490, Reward Sum 34.0.
[04-20 23:36:19 MainThread @train.py:111] Test reward: 500.0
[04-20 23:36:19 MainThread @train.py:101] Episode 500, Reward Sum 156.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 510, Reward Sum 119.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 520, Reward Sum 178.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 530, Reward Sum 132.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 540, Reward Sum 308.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 550, Reward Sum 155.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 560, Reward Sum 8.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 570, Reward Sum 500.0.
[04-20 23:36:21 MainThread @train.py:101] Episode 580, Reward Sum 154.0.
[04-20 23:36:21 MainThread @train.py:101] Episode 590, Reward Sum 62.0.

After training finishes, run a rendered evaluation:

Uncomment the load-and-evaluate block shown in the figure below (the commented-out lines near the top of main()), then run the script again.

(In the original example that commented-out call was missing one argument; it has already been fixed in the code above.)
(figure: the commented-out evaluation block in main())
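For reference, once uncommented, that block at the top of main() reads as follows (it loads the saved weights, plays five rendered episodes, and exits):

    # load model and evaluate
    if os.path.exists('./model.ckpt'):
        agent.restore('./model.ckpt')
        run_evaluate_episodes(agent, 5, render=True)
        exit()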

What the rendered run looks like during the evaluate phase:

(figure: the CartPole window rendered during evaluation)

Plotting the training curve

To log the per-episode return, one line has to be added to train.py, together with the corresponding import (the sketch right below shows exactly where it goes):

from parl.utils import summary
summary.add_scalar('Train Reward Sum', sum(reward_list), i)
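A minimal sketch of where these lines fit in train.py: the import goes at the top alongside the other imports, and the add_scalar call sits inside the episode loop in main(), right after run_train_episode:

    # in main(), inside the episode loop (summary imported at the top of train.py)
    for i in range(args.max_episodes):
        obs_list, action_list, reward_list = run_train_episode(agent, env)
        # new line: log the raw episode return so TensorBoard can plot the training curve
        summary.add_scalar('Train Reward Sum', sum(reward_list), i)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))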
Run the training once more so the data points get recorded; the resulting event file is written under the train folder.
(figure: the event file generated under the train folder)
This file cannot be opened directly; it has to be viewed with TensorBoard.

First install the corresponding library:

pip install tensorboard -i https://pypi.douban.com/simple

(figure: installing tensorboard)

Then change into the directory that contains the event file,
(figure: the directory containing the event file)

and run:

tensorboard --logdir=. --port=8008

Once it has started, it prints a URL:
(figure: TensorBoard startup output showing the URL)
Open that URL in the browser to view the dashboard.
The slider on the left adjusts how much the curve is smoothed; with smoothing set to 0.99 the plot below shows that training clearly works: the pole is kept upright for longer and longer.

(figure: the Train Reward Sum curve in TensorBoard)
