强化学习(一)--Sarsa与Q-learning算法
最近实验室有一个项目要用到强化学习,在这开个新坑来记录下强化学习的学习过程。
第一节就先来最简单的基于表格型的RL算法,包括经典的Sarsa和Q-learning算法。
由于时间原因,关于算法的理论知识不再详细介绍,重点是研究怎么编程实现,代码是参考的飞浆PaddlePaddle公开课的代码,下来又自己手撸了一遍。飞浆PaddlePaddle公开课是我认为最适合入门强化学习的公开课,科老师讲解的真的非常清晰,公开课地址。
1. SARSA算法
sarsa算法是最基础的on-policy算法,它采用的是TD单步更新的方式,每一个step都会更新Q表格,Q表格的更新公式为:这也是代码最核心的部分,它就是将Q值不断逼近目标值,也就是未来总收益。
Sarsa的名字就来源于它更新Q表格时所用到的五个参数:S,A,R,S’,A’,它的算法伪代码为:
第一次看伪代码可能会有些懵,公开课里很贴心的给出了流程图:
 和测试函数test_episode() 也是接下来要实现的。
- 共进行500个episode的训练,每个episode都输出进行多少步和总的reward,每20个episode,我们输出可视化一下。
- 训练结束后我们测试下结果。
# 主函数
def main():
# 导入环境
env = gym.make("CliffWalking-v0")
env = CliffWalkingWapper(env)
# env = gym.make("FrozenLake-v0",is_slippery = False)
# env = FrozenLakeWapper(env)
agent = SarsaAgent(
obs_n = env.observation_space.n,
act_n = env.action_space.n,
learning_rate = 0.1,
gamma = 0.9,
e_greed = 0.1)
is_render = False
# 进行500个轮次的训练
for episode in range(1000):
ep_reward