【动手学强化学习】part5-值函数近似算法

阐述、总结【动手学强化学习】章节内容的学习情况,复现并理解代码。


一、算法背景

1.1算法目标

给定“黑盒”环境,求解最优policy

1.2问题

前序章节中以MC或TD方法构建model-free算法,以求解“黑盒”模型下的最优policy,但在action value(q(s,a))估计以Q_table的表格形式存储记录,在状态、动作空间较小时能够很好适应,但状态、动作空间扩大以后,算法运行时将承载巨大的存储压力。更甚者,当状态或者动作连续的时候,就有无限个状态动作对,我们更加无法使用这种表格形式来记录各个状态动作对的q(s,a)值。

1.3解决方法

  • 🌟函数拟合
    需要用函数拟合的方法来估计q(s,a)值,即将这个复杂的q(s,a)值表格视作数据,使用一个参数化的函数q(s,a,θ)来拟合这些数据
    θ为拟合函数的参数。
    这种函数拟合的方法存在一定的精度损失,因此被称为值函数近似方法(function approximation)
    ✅即将介绍的 DQN 算法便可以用来解决连续状态下离散动作的问题。

二、DQN算法

  • 🌟算法类型
    环境依赖:❌model-based ✅model-free
    价值估计:✅non-incremental ❌incremental(当replay_buffer数据的数量超过一定值后,才进行Q网络训练,并进行Q值估计)
    价值表征:❌tabular representation ✅function representation(不再基于Q-table方式存储q(s,a)值,而是采用“函数拟合”)
    学习方式:❌on-policy ✅off-policy
    策略表征:✅value-based ❌policy-based

2.1必要说明

Q网络建模

本节算法需要通过值函数近似的方法进行Q(s,a)的估计,相对于线性函数拟合,深度神经网络在函数拟合方面有更高的精度,因此鉴于神经网络具有强大的表达能力,因此我们可以用一个神经网络来表示动作价值函数Q(s,a)。
一般而言Q网络的输入输出有三种常见建模方式:
①输入:(s,a),输出:标量Q
②输入:s,输出:所有动作空间的Q
③输入:s,输出:max_Q

经验回放

在一般的有监督学习中,假设训练数据是独立同分布的,我们每次训练神经网络的时候从训练数据中随机采样一个或若干个数据来进行梯度下降,随着学习的不断进行,每一个训练数据会被使用多次
在原来的 Q-learning 算法中,每一个数据只会用来更新一次q值。
DQN 算法采用了经验回放(experience replay)方法,具体做法为维护一个回放缓冲区(replay buffer),将每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,训练 Q 网络的时候再从回放缓冲区中随机采样若干数据来进行训练。

  • 经验回放(experience replay)优势
    使样本满足独立假设。在 MDP 中交互采样得到的数据本身不满足独立假设,因为这一时刻的状态和上一时刻的状态有关。非独立同分布的数据对训练神经网络有很大的影响,会使神经网络拟合到最近训练的数据上。采用经验回放可以打破样本之间的相关性,让其满足独立假设。
    提高样本效率
    。每一个样本可以被使用多次,十分适合深度神经网络的梯度学习。

双Q网络更新

回顾基于时序差分的更新Q(s,a)的过程
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s^{\prime},a^{\prime})-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]
于是,对于采样得到的N个数据{(s,a,r,s’)} ,我们可以很自然地将 Q 网络的损失函数构造为均方误差的形式:
ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_\omega\frac1{2N}\sum_{i=1}^N\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a^{\prime}}Q_\omega\left(s_i^{\prime},a^{\prime}\right)\right)\right]^2 ω=argωmin2N1i=1N[Qω(si,ai)(ri+γamaxQω(si,a))]2

由于DQN 算法最终更新的目标是让 Q ω ( s , a ) Q_\omega(s,a) Qω(s,a)逼近TD target( r + γ max ⁡ a ′ Q ω ( s ′ , a ′ ) r+\gamma\operatorname*{max}_{a^{\prime}}Q_{\omega}\left(s^{\prime},a^{\prime}\right) r+γmaxaQω(s,a)),但由于TD error内本身就包含神经网络的输出,因此在更新网络参数的同时目标也在不断地改变,这非常容易造成神经网络训练的不稳定性

  • 解决方法
    使用双Q网络,即main networktarget network
    用main network去估计 Q ω ( s , a ) Q_\omega(s,a) Qω(s,a)
    用target network去估计 m a x a ′ Q ω ( s ′ , a ′ ) {max}_{a^{\prime}}Q_{\omega}\left(s^{\prime},a^{\prime}\right) maxaQω(s,a)

2.2伪代码

在这里插入图片描述

  • 算法流程简述:
    ①初始化Q网络模型:设置main_network,target_network网络层数、神经元个数、激活函数等;
    ②初始化“经验回放池”:设置经验回放池大小,即样本(s,a,r,s’,done)的个数,done为标志位,表示是否达到terminal state;
    ③采样填“池”:根据main_network不断step()获取样本存放至经验回放池;
    ④训练main_network:根据main_network估计(s,a)的Q值,根据target_network估计s’的最优Q值(即TD_target),将main_network的损失函数设置为TD_error的均方误差,并训练更新main_network;
    ⑤更新target_network:当main_network更新次数达到设置的阈值(例如:count=10)后,将main_network的网络参数复制给target_network。

2.3算法代码

import random
import gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils


class ReplayBuffer:
    ''' 经验回放池 '''

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)  # 队列,先进先出

    def add(self, state, action, reward, next_state, done):  # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):  # 从buffer中采样数据,数量为batch_size
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):  # 目前buffer中数据的数量
        return 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值