0 简介
baselines是OpenAI推出的一套强化学习算法组件,用于快速配置强化学习算法,对入门者比较友好
1 安装
pip install stable-baselines
2 参数介绍
Base RL Class
common interface for all the RL algorithms
class stable_baselines.common.base_class.BaseRLModel(policy,env,verbos=0,*,requires_vec_env,policy_base,policy_kwargs=None,seed=None,n_cpu_tf_sess=None)
The base RL model
Parameters:
policy - ( BasePolicy )Policy object
policy: 策略模型选择,用于建立状态/状态-动作对和策略之间的联系,底层是多层感知机或卷积网络。
env: [Gym environment] The environment to learn from [if registered in Gym, can be str. Can be None for loading trained models]
env:
必要方法:step(action)、reset()、render()
必要元素:action_space、observation_space
step(action):仿真的步进,如何接受一个action然后进行一步仿真
reset():重置
render():显示
action_space:连续的/离散的。比如离散的,东西南北四个方向。比如连续的,选择一个区间产生一个数字,作为他的一个运动步长。
observation_space也是一样的,比如机器人右6个关节,他的状态就用他的6个位置和6个速度来表示。这个速度和位置都有一个上下线的范围。这个范围就可以作为他的observation_space
满足上述5个要素的存在的话,这个环境就可以传送到stable-baselines里面进行下一步的训练了
应用
通过stable_baselines建立DQN框架,训练并运行倒立摆(CartPole-v0)
from stable_baselines import DQN
from stable_baselines.common.evaluation import evaluate_policy
import gym
import time
env = gym.make('CartPole-v0') # 传入倒立摆
TRAIN = 0
if TRAIN: # 训练的部分
model = DQN('MlpPolicy', env, learning_rate=1e-3, prioritized_replay=True, verbose=1) # 属于一个可接受离散的这样的一个网络
# MlpPolicy,多层感知机或者神经网络的一个策略
# env,传入的一个环境
# 其他的一些参数,到这个文件夹下面去看一下,不想细说了,每一个都有一个详细的解释
model.learn(total_timesteps=int(1e5)) # 开始训练,直接用model.learn就可以了,这个learn中也会涉及一些参数
model.save("dqn_cartpole") #训练之后呢,就可以保存这样的一个模型
del model #训练结束后,这个模型就用不到了,就可以删掉了
else: # 演示的部分
model = DQN.load("dqn_cartpole", env) # 调用已经训练好的模型,从神经网络中调用
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
obs = env.reset() # 状态重置
for i in range(1000):
action, _states = model.predict(obs) # 将当前的状态传入,告诉我们会做出什么样的动作。
obs, rewards, done, info = env.step(action)# 返回一个新的状态
env.render() # 做一个显示
time.sleep(2) # for showing render()