Quick Start: Solving the CartPole Problem (https://parl.readthedocs.io/zh-cn/latest/tutorial/getting_started.html)
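CartPole is also known as the inverted pendulum. A pole stands on a cart and tends to fall over under gravity.
To keep the pole from falling, we move the cart so that the pole stays upright, as shown in the figure below.
At each time step, the model's input is a 4-dimensional vector describing the current state of the cart and pole (cart position, cart velocity, pole angle, pole angular velocity), and the model's output decides whether the cart is pushed left or right.
While the pole has not fallen, the environment gives a reward of 1 at every time step;
once the pole falls, the environment gives no further reward and the episode ends.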
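To make the 4-dimensional state, the two actions and the +1-per-step reward concrete, here is a minimal pure-Python sketch of a cart-pole simulator. The physical constants and the Euler update are assumed from the classic cart-pole equations as used in gym's `CartPole-v1` source; the class `MiniCartPole` and the helper `run_episode` are hypothetical names for illustration only, not part of gym or PARL.

```python
import math
import random


class MiniCartPole:
    """Hypothetical, simplified cart-pole modeled on gym's CartPole-v1 dynamics."""

    gravity, masscart, masspole = 9.8, 1.0, 0.1
    length = 0.5                           # half of the pole length
    force_mag, tau = 10.0, 0.02            # push force and time step (seconds)
    theta_limit = 12 * 2 * math.pi / 360   # pole counts as fallen beyond ~12 degrees
    x_limit = 2.4                          # cart leaves the track beyond +-2.4
    max_steps = 500                        # CartPole-v1 episode cap

    def reset(self, seed=None):
        rng = random.Random(seed)
        # State: [cart position, cart velocity, pole angle, pole angular velocity]
        self.state = [rng.uniform(-0.05, 0.05) for _ in range(4)]
        self.steps = 0
        return list(self.state)

    def step(self, action):
        x, x_dot, theta, theta_dot = self.state
        # Action 1 pushes the cart right, action 0 pushes it left
        force = self.force_mag if action == 1 else -self.force_mag
        total_mass = self.masscart + self.masspole
        pml = self.masspole * self.length
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        temp = (force + pml * theta_dot ** 2 * sin_t) / total_mass
        theta_acc = (self.gravity * sin_t - cos_t * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * cos_t ** 2 / total_mass))
        x_acc = temp - pml * theta_acc * cos_t / total_mass
        # Explicit Euler integration (right-hand sides use the old values)
        x, x_dot = x + self.tau * x_dot, x_dot + self.tau * x_acc
        theta, theta_dot = theta + self.tau * theta_dot, theta_dot + self.tau * theta_acc
        self.state = [x, x_dot, theta, theta_dot]
        self.steps += 1
        fell = abs(x) > self.x_limit or abs(theta) > self.theta_limit
        done = fell or self.steps >= self.max_steps
        return list(self.state), 1.0, done  # +1 for every step taken


def run_episode(env, policy, seed=0):
    """Play one episode; the return equals the number of steps survived."""
    obs, total = env.reset(seed), 0.0
    while True:
        obs, reward, done = env.step(policy(obs))
        total += reward
        if done:
            return total


env = MiniCartPole()
print(run_episode(env, lambda obs: 1))                                # always push right
print(run_episode(env, lambda obs: 1 if obs[2] + obs[3] > 0 else 0))  # push toward the tilt
```

Because the reward is 1 per step, the episode return is simply how long the pole stayed up, which is why the "Reward Sum" values in the training log below can be read as survival time.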

Environment
Windows 11, AMD Ryzen 3700 CPU, x86-64 (AMD64) architecture
Python 3.8.10
Installed package versions:
absl-py==2.2.2
anyio==4.5.2
astor==0.8.1
blinker==1.8.2
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==1.6.0
colorama==0.4.6
decorator==4.4.2
exceptiongroup==1.2.2
filelock==3.16.1
Flask==3.0.3
Flask-Cors==5.0.0
fsspec==2025.3.0
gast==0.3.3
google-auth==2.39.0
google-auth-oauthlib==0.4.6
grpcio==1.37.0
gym==0.26.2
gym-notices==0.0.8
h11==0.14.0
httpcore==1.0.8
httpx==0.28.1
idna==3.10
importlib_metadata==8.5.0
itsdangerous==2.2.0
Jinja2==3.1.6
Markdown==3.7
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.1
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
paddle-bfloat==0.1.7
paddlepaddle==2.5.0
parl==2.2.1
pillow==10.4.0
protobuf==3.20.0
psutil==7.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
pygame==2.6.1
pynvml==11.5.3
pyzmq==18.1.1
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
scipy==1.10.0
six==1.17.0
sniffio==1.3.1
sympy==1.13.3
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5
termcolor==2.4.0
typing_extensions==4.13.2
urllib3==2.2.3
Werkzeug==3.0.6
zipp==3.20.2
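If you want to confirm that your environment matches the most important pins above before running anything, a small check like the following may help. It is a sketch that uses only the standard library's `importlib.metadata` (available from Python 3.8, matching the Python version listed); the helper name `check_versions` and the choice of which packages to check are illustrative.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

# A few pins from the list above that matter most for this example
EXPECTED = {
    "gym": "0.26.2",
    "numpy": "1.23.5",
    "paddlepaddle": "2.5.0",
    "parl": "2.2.1",
}


def check_versions(expected):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    for name, (want, have) in check_versions(EXPECTED).items():
        print(f"{name}: expected {want}, found {have}")
```

An empty result means the checked packages match the versions used in this post.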
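Implementation:
Just copy the files from PARL's examples directory:
https://github.com/PaddlePaddle/PARL/tree/develop/examples/QuickStart
There are three files in total: train.py, cartpole_model.py and cartpole_agent.py.

Below is the full code of train.py. I modified the original source slightly, which is why I'm posting it.
The other two files, cartpole_agent.py and cartpole_model.py, need no changes, so I haven't included them.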
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gym
import numpy as np
import parl
from parl.utils import logger
from parl.env import CompatWrapper, is_gym_version_ge
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
import argparse

LEARNING_RATE = 1e-3


# train an episode
def run_train_episode(agent, env):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)  # stochastic action, keeps exploring
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


# evaluate 5 episodes
def run_evaluate_episodes(agent, eval_episodes=5, render=False):
    # Compatible for different versions of gym
    if is_gym_version_ge("0.26.0") and render:  # if gym version >= 0.26.0
        env = gym.make('CartPole-v1', render_mode="human")
    else:
        env = gym.make('CartPole-v1')
    env = CompatWrapper(env)
    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)  # greedy action
            obs, reward, isOver, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if isOver:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_i = r_i + γ·G_i+1
        reward_list[i] += gamma * reward_list[i + 1]  # Gt
    return np.array(reward_list)


def main():
    env = gym.make('CartPole-v1')
    # Compatible for different versions of gym
    env = CompatWrapper(env)
    # env = env.unwrapped  # Remove the 500-step episode limit (TimeLimit wrapper)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
    print(model)
    alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    agent = CartpoleAgent(alg)

    # load model and evaluate
    # if os.path.exists('./model.ckpt'):
    #     agent.restore('./model.ckpt')
    #     run_evaluate_episodes(agent, 5, render=True)
    #     exit()

    for i in range(args.max_episodes):
        obs_list, action_list, reward_list = run_train_episode(agent, env)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            total_reward = run_evaluate_episodes(agent, eval_episodes=5, render=False)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Environment
    parser.add_argument(
        '--max_episodes',
        type=int,
        default=1000,
        help='stop condition: number of episodes')
    args = parser.parse_args()
    main()
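`calc_reward_to_go` turns the per-step rewards into returns G_i = r_i + γ·G_{i+1} by walking the list backwards. The standalone sketch below reproduces that recurrence in plain Python so the numbers can be checked by hand. `calc_reward_to_go` here is a copy-based variant written for illustration, and `calc_reward_to_go_inplace` mirrors the original's behavior of overwriting `reward_list`; both are hypothetical stand-ins, not the code in train.py.

```python
def calc_reward_to_go(reward_list, gamma=1.0):
    """Same recurrence as train.py, but on a copy: G_i = r_i + gamma * G_{i+1}."""
    returns = list(reward_list)
    for i in range(len(returns) - 2, -1, -1):
        returns[i] += gamma * returns[i + 1]
    return returns


def calc_reward_to_go_inplace(reward_list, gamma=1.0):
    """Mirrors the original: overwrites reward_list element by element."""
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += gamma * reward_list[i + 1]
    return reward_list


# CartPole gives +1 per step, so a 4-step episode has rewards [1, 1, 1, 1]
rewards = [1.0, 1.0, 1.0, 1.0]
print(calc_reward_to_go(rewards))  # [4.0, 3.0, 2.0, 1.0]
print(sum(rewards))                # 4.0: still the episode reward

overwritten = [1.0, 1.0, 1.0, 1.0]
calc_reward_to_go_inplace(overwritten)
print(sum(overwritten))            # 10.0: no longer the episode reward
```

With gamma=1.0 (the default used in train.py), the first step of an episode is credited with the whole remaining episode length. The in-place variant shows why any `sum(reward_list)` has to be computed before the call, as the logging line in `main()` already is.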
Training output:
[04-20 23:30:51 MainThread @utils.py:77] paddlepaddle version: 2.5.0.
[04-20 23:30:51 MainThread @train.py:84] obs_dim 4, act_dim 2
CartpoleModel(
(fc1): Linear(in_features=4, out_features=20, dtype=float32)
(fc2): Linear(in_features=20, out_features=2, dtype=float32)
)
[04-20 23:30:51 MainThread @train.py:101] Episode 0, Reward Sum 29.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 10, Reward Sum 28.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 20, Reward Sum 67.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 30, Reward Sum 19.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 40, Reward Sum 42.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 50, Reward Sum 34.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 60, Reward Sum 11.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 70, Reward Sum 26.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 80, Reward Sum 35.0.
[04-20 23:30:51 MainThread @train.py:101] Episode 90, Reward Sum 13.0.
[04-20 23:31:28 MainThread @train.py:111] Test reward: 177.6
[04-20 23:31:28 MainThread @train.py:101] Episode 100, Reward Sum 18.0.
[04-20 23:31:28 MainThread @train.py:101] Episode 110, Reward Sum 26.0.
[04-20 23:31:28 MainThread @train.py:101] Episode 120, Reward Sum 41.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 130, Reward Sum 78.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 140, Reward Sum 38.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 150, Reward Sum 49.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 160, Reward Sum 25.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 170, Reward Sum 23.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 180, Reward Sum 41.0.
[04-20 23:31:29 MainThread @train.py:101] Episode 190, Reward Sum 34.0.
[04-20 23:32:17 MainThread @train.py:111] Test reward: 233.6
[04-20 23:32:17 MainThread @train.py:101] Episode 200, Reward Sum 107.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 210, Reward Sum 52.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 220, Reward Sum 126.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 230, Reward Sum 11.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 240, Reward Sum 41.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 250, Reward Sum 98.0.
[04-20 23:32:17 MainThread @train.py:101] Episode 260, Reward Sum 122.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 270, Reward Sum 47.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 280, Reward Sum 76.0.
[04-20 23:32:18 MainThread @train.py:101] Episode 290, Reward Sum 45.0.
[04-20 23:33:11 MainThread @train.py:111] Test reward: 258.8
[04-20 23:33:11 MainThread @train.py:101] Episode 300, Reward Sum 14.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 310, Reward Sum 85.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 320, Reward Sum 17.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 330, Reward Sum 92.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 340, Reward Sum 110.0.
[04-20 23:33:11 MainThread @train.py:101] Episode 350, Reward Sum 180.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 360, Reward Sum 71.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 370, Reward Sum 135.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 380, Reward Sum 125.0.
[04-20 23:33:12 MainThread @train.py:101] Episode 390, Reward Sum 160.0.
[04-20 23:34:34 MainThread @train.py:111] Test reward: 400.0
[04-20 23:34:34 MainThread @train.py:101] Episode 400, Reward Sum 117.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 410, Reward Sum 75.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 420, Reward Sum 153.0.
[04-20 23:34:34 MainThread @train.py:101] Episode 430, Reward Sum 296.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 440, Reward Sum 45.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 450, Reward Sum 183.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 460, Reward Sum 108.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 470, Reward Sum 12.0.
[04-20 23:34:35 MainThread @train.py:101] Episode 480, Reward Sum 190.0.
[04-20 23:34:36 MainThread @train.py:101] Episode 490, Reward Sum 34.0.
[04-20 23:36:19 MainThread @train.py:111] Test reward: 500.0
[04-20 23:36:19 MainThread @train.py:101] Episode 500, Reward Sum 156.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 510, Reward Sum 119.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 520, Reward Sum 178.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 530, Reward Sum 132.0.
[04-20 23:36:19 MainThread @train.py:101] Episode 540, Reward Sum 308.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 550, Reward Sum 155.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 560, Reward Sum 8.0.
[04-20 23:36:20 MainThread @train.py:101] Episode 570, Reward Sum 500.0.
[04-20 23:36:21 MainThread @train.py:101] Episode 580, Reward Sum 154.0.
[04-20 23:36:21 MainThread @train.py:101] Episode 590, Reward Sum 62.0.
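The log above prints the network: `fc1` maps the 4-dimensional observation to 20 hidden units and `fc2` maps those to 2 action scores. Below is a NumPy stand-in for that pipeline, written only for illustration. The `tanh` hidden activation and softmax output are my assumption about cartpole_model.py (check the file), and the loss is the standard REINFORCE objective, the mean of -log π(a|s)·G, which is presumably what `parl.algorithms.PolicyGradient` minimizes. The mapping of sampling to `agent.sample` (training) and argmax to `agent.predict` (evaluation) is likewise an assumption; all function names and weights here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, HID_DIM, ACT_DIM = 4, 20, 2  # matches the printed CartpoleModel shapes

# Hypothetical random weights standing in for fc1 / fc2
W1, b1 = rng.normal(0, 0.1, (OBS_DIM, HID_DIM)), np.zeros(HID_DIM)
W2, b2 = rng.normal(0, 0.1, (HID_DIM, ACT_DIM)), np.zeros(ACT_DIM)


def policy(obs):
    """Action probabilities for a batch of observations, shape (N, ACT_DIM)."""
    hidden = np.tanh(obs @ W1 + b1)                # assumed hidden activation
    logits = hidden @ W2 + b2
    logits -= logits.max(axis=-1, keepdims=True)   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)   # softmax


def sample(obs):
    """Training: draw an action from pi(.|s) so the agent keeps exploring."""
    prob = policy(obs[None, :])[0]
    return int(rng.choice(ACT_DIM, p=prob))


def predict(obs):
    """Evaluation: pick the most probable action greedily."""
    return int(np.argmax(policy(obs[None, :])[0]))


def reinforce_loss(batch_obs, batch_action, batch_return):
    """Mean of -log pi(a_t|s_t) * G_t over one episode."""
    prob = policy(batch_obs)
    log_prob = np.log(prob[np.arange(len(batch_action)), batch_action])
    return float(np.mean(-log_prob * batch_return))


obs = rng.normal(0, 0.05, (3, OBS_DIM))
actions = np.array([sample(o) for o in obs])
returns = np.array([3.0, 2.0, 1.0])  # reward-to-go of a 3-step episode
print(policy(obs).sum(axis=1))       # each row sums to 1
print(predict(obs[0]), reinforce_loss(obs, actions, returns))
```

Minimizing this loss raises the log-probability of actions that were followed by long survival (large G) more than those followed by an early fall, which is the trend visible in the rising "Test reward" values.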
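After training finishes, run a visual test:
uncomment the lines shown in the screenshot below (the "load model and evaluate" block in main()) and run the script again.
(No idea who wrote that commented-out code, but the call was missing an argument; I've already added it in the code above.)
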
What the EVALUATE stage looks like when running:

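Plotting the training curve:
Add one line to the training loop in main() and import the corresponding module at the top of train.py:
from parl.utils import summary
summary.add_scalar('Train Reward Sum', sum(reward_list), i)
Place the add_scalar call right after run_train_episode (before calc_reward_to_go), because calc_reward_to_go overwrites reward_list in place and sum(reward_list) would then no longer be the episode reward.
Then run training again to record the data points; the data is written under the train folder.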

This file can't be opened directly; it has to be viewed with the right tool, TensorBoard:

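First install the corresponding package (tensorboard 2.11.0 is already in the package list above; -i only selects a PyPI mirror, so drop it or use another mirror if this one is unreachable):
pip install tensorboard -i https://pypi.douban.com/simple
As shown in the screenshot: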

Then cd into the directory that contains this file.

Then run
tensorboard --logdir=. --port=8008
Once it starts, it prints:

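Then click the blue URL to open the dashboard.
The slider on the left controls how strongly the curve is smoothed; I set it to 0.99 and got the curve shown.
The training clearly works: the smoothed reward keeps rising, meaning the agent keeps the pole upright longer and longer.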
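As far as I know, TensorBoard's smoothing slider applies an exponential moving average with a start-up bias correction, so at 0.99 each plotted point reflects roughly the last 1/(1 - 0.99) = 100 logged points. Here is a small pure-Python sketch of such a debiased EMA for intuition; `debiased_ema` is a hypothetical helper that approximates the idea and is not TensorBoard's source code.

```python
def debiased_ema(values, weight=0.99):
    """Exponential moving average with start-up bias correction."""
    smoothed, last = [], 0.0
    for n, v in enumerate(values, start=1):
        last = last * weight + (1 - weight) * v
        # Dividing by (1 - weight**n) removes the pull toward the initial 0
        smoothed.append(last / (1 - weight ** n))
    return smoothed


# The first ten "Reward Sum" values from the training log above
rewards = [29.0, 28.0, 67.0, 19.0, 42.0, 34.0, 11.0, 26.0, 35.0, 13.0]
print([round(x, 1) for x in debiased_ema(rewards, weight=0.9)])
```

A lower weight (0.9 here, since there are only ten points) follows the raw values more closely; 0.99 flattens the episode-to-episode noise seen in the log and leaves only the long-term trend.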





