Deep Reinforcement Learning (DRL) 5: Prioritized Experience Replay DQN

Full code

https://github.com/ColinFred/Reinforce_Learning_Pytorch/tree/main/RL/DQN

1. Prioritized Replay

Standard experience replay samples transitions uniformly at random, which is not especially efficient: to the agent, not every transition is equally important. Prioritized Replay addresses this by breaking uniform sampling and giving the samples the agent can learn the most from a larger sampling probability.

An ideal criterion would be: the more the agent can learn from a sample, the larger its sampling weight. One quantity that fits this criterion is the TD error δ. The larger the TD error, the further the value estimate at that state is from the TD target, so the larger the update and the more there is to learn from that transition.

In short, each transition in the replay buffer is given a sampling priority.

Prioritized replay DQN involves three main changes:

1. To make prioritized storage and sampling efficient, a sum tree (SumTree) is used as the underlying data structure.

The original paper gives two ways to turn priorities into sampling probabilities: proportional priority and rank-based priority. With proportional priority, the probability of drawing a sample is proportional to a priority derived from its TD error; with rank-based priority, it is proportional to the rank of the transition's priority. Here proportional priority is used, so the probability of drawing a transition grows with its TD error.
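
For reference, the two priority schemes from the PER paper can be written as follows, where $\delta_i$ is the TD error, $\epsilon > 0$ prevents zero priority, and $\alpha \ge 0$ controls how strongly prioritization is applied ($\alpha = 0$ recovers uniform sampling):

$$p_i = |\delta_i| + \epsilon \;\;\text{(proportional)}, \qquad p_i = \frac{1}{\mathrm{rank}(i)} \;\;\text{(rank-based)}, \qquad P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}$$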

In addition, to guarantee that every newly stored transition has a chance of being sampled, a new transition is assigned a large priority (the current maximum).

2. The loss is computed with a per-sample weight $w_j$; in the original paper this is the importance-sampling weight $w_j = \left(N \cdot P(j)\right)^{-\beta} / \max_i w_i$, which corrects the bias introduced by non-uniform sampling:

$$\frac{1}{m}\sum_{j=1}^{m} w_j \left(y_j - Q(s_j, a_j, w)\right)^2$$

3. Every time the Q-network parameters are updated, the TD errors $\delta_j = y_j - Q(s_j, a_j, w)$ of the sampled transitions must be recomputed and their priorities updated accordingly.

2. Code

Prioritized experience replay is combined here with the Double DQN and Dueling DQN from the previous posts.
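
Because Double DQN is used, the TD target in the code below selects the next action with the policy network and evaluates it with the target network:

$$y_j = r_j + \gamma \, Q_{\text{target}}\!\left(s'_j,\ \arg\max_{a} Q_{\text{policy}}(s'_j, a)\right)$$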

SumTree and ReplayMemory_Per

SumTree implements: add() to insert an experience, get() to sample a leaf by priority, and update() to change the priority of a stored transition.

ReplayMemory_Per implements: push() to insert a new experience, sample() to draw transitions by priority, and update() to refresh the priorities of stored experiences.
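
As a quick illustration of the SumTree behaviour (toy priorities, not taken from the repo), using the SumTree class defined in the code below: with four leaves of priority [3, 1, 4, 2] the root stores the total 10, and get(s) walks down the tree so that each leaf is returned with probability proportional to its priority; both get() and update() cost O(log N).

tree = SumTree(capacity=4)
for p, data in zip([3, 1, 4, 2], ['a', 'b', 'c', 'd']):
    tree.add(p, data)

print(tree.total())   # 10.0 -- the root holds the sum of all leaf priorities
print(tree.get(2.5))  # leaf 'a' (priority 3): 2.5 falls inside the first leaf's mass
print(tree.get(8.5))  # leaf 'd' (priority 2): 8.5 - 3 - 1 - 4 = 0.5 lands in the last leaf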


import numpy as np
from collections import namedtuple
from random import uniform

# Experience tuple stored in the replay buffer; field order matches store_transition() below
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])


class ReplayMemory_Per(object):
    # stored as ( s, a, s_, r ) Transition tuples in the SumTree
    def __init__(self, capacity=1000, a=0.6, e=0.01):
        self.tree = SumTree(capacity)
        self.memory_size = capacity
        self.prio_max = 0.1
        self.a = a
        self.e = e

    def push(self, *args):
        data = Transition(*args)
        p = (np.abs(self.prio_max) + self.e) ** self.a  # proportional priority
        self.tree.add(p, data)

    def sample(self, batch_size):
        idxs = []
        segment = self.tree.total() / batch_size
        sample_datas = []

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = uniform(a, b)
            idx, p, data = self.tree.get(s)

            sample_datas.append(data)
            idxs.append(idx)
        return idxs, sample_datas

    def update(self, idxs, errors):
        self.prio_max = max(self.prio_max, max(np.abs(errors)))
        for i, idx in enumerate(idxs):
            p = (np.abs(errors[i]) + self.e) ** self.a
            self.tree.update(idx, p)

    def size(self):
        return self.tree.n_entries
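
A minimal usage sketch of the push / sample / update cycle (not part of the repo; the transitions and TD errors below are placeholder values only):

memory = ReplayMemory_Per(capacity=1000)

for step in range(8):
    # In a real loop these come from the environment: (state, action, next_state, reward)
    memory.push(float(step), 0, float(step + 1), 1.0)

idxs, transitions = memory.sample(batch_size=4)   # sampling probability is proportional to stored priority
# ... compute TD errors for the sampled transitions with the Q-networks ...
td_errors = [0.5, 1.2, 0.1, 0.8]                  # hypothetical values
memory.update(idxs, td_errors)                    # write the new priorities back into the SumTree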

Every time the Q-network parameters are updated, the TD errors are recomputed and the priorities in the SumTree are updated.

The importance-sampling weights on the loss (change 2 above) are not used in this implementation; a sketch of how they could be added is given below.
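
For completeness, here is a hedged sketch of how the importance-sampling weights could be wired in. Neither sample_with_weights() nor the beta parameter exists in the repo; they are illustrative only.

import numpy as np
import torch
from random import uniform

def sample_with_weights(memory, batch_size, beta=0.4):
    """Sketch: draw a prioritized batch and also return importance-sampling weights.

    `memory` is a ReplayMemory_Per instance; beta would normally be annealed toward 1.
    """
    idxs, datas, priorities = [], [], []
    segment = memory.tree.total() / batch_size
    for i in range(batch_size):
        s = uniform(segment * i, segment * (i + 1))
        idx, p, data = memory.tree.get(s)
        idxs.append(idx)
        datas.append(data)
        priorities.append(p)
    probs = np.array(priorities) / memory.tree.total()
    weights = (memory.size() * probs) ** (-beta)   # w_i = (N * P(i))^(-beta)
    weights = weights / weights.max()              # normalize so the largest weight is 1
    return idxs, datas, torch.FloatTensor(weights)

# In PerDQN.learn() below, the plain MSE would then become a weighted MSE, for example:
# loss = (weights.unsqueeze(1) * (state_action_values - expected_state_action_values) ** 2).mean()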



import torch
import torch.nn.functional as F
import torch.optim as optim
from random import randrange

# DNN (the Q-network), EPISILO, BATCH_SIZE and GAMMA are defined elsewhere in the repo.


class PerDQN:
    def __init__(self, n_action, n_state, learning_rate):

        self.n_action = n_action
        self.n_state = n_state

        self.memory = ReplayMemory_Per(capacity=100)
        self.memory_counter = 0

        self.model_policy = DNN(self.n_state, self.n_action)
        self.model_target = DNN(self.n_state, self.n_action)
        self.model_target.load_state_dict(self.model_policy.state_dict())
        self.model_target.eval()

        self.optimizer = optim.Adam(self.model_policy.parameters(), lr=learning_rate)

    def store_transition(self, s, a, r, s_):
        state = torch.FloatTensor([s])
        action = torch.LongTensor([a])
        reward = torch.FloatTensor([r])
        next_state = torch.FloatTensor([s_])
        self.memory.push(state, action, next_state, reward)

    def choose_action(self, state):
        state = torch.FloatTensor(state)
        if np.random.rand() <= EPISILO:  # greedy policy: exploit with probability EPISILO
            with torch.no_grad():
                q_value = self.model_policy(state)
                action = q_value.max(0)[1].view(1, 1).item()
        else:  # random policy
            action = torch.tensor([randrange(self.n_action)], dtype=torch.long).item()

        return action

    def learn(self):
        if self.memory.size() < BATCH_SIZE:
            return
        idxs, transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action).unsqueeze(1)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        state_action_values = self.model_policy(state_batch).gather(1, action_batch)

        # Double DQN: the policy network selects the next action, the target network evaluates it
        next_action_batch = torch.unsqueeze(self.model_policy(next_state_batch).max(1)[1], 1)
        next_state_values = self.model_target(next_state_batch).gather(1, next_action_batch)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch.unsqueeze(1)

        td_errors = (state_action_values - expected_state_action_values).detach().squeeze().tolist()
        self.memory.update(idxs, td_errors)  # update td error
        loss = F.mse_loss(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model_policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def update_target_network(self):
        self.model_target.load_state_dict(self.model_policy.state_dict())
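
Finally, a hedged sketch of how the pieces could fit together in a training loop. The environment, episode count, target-update interval and learning rate are assumptions (using the older gym step()/reset() API), not taken from the repo.

import gym

env = gym.make('CartPole-v0')
agent = PerDQN(n_action=env.action_space.n,
               n_state=env.observation_space.shape[0],
               learning_rate=1e-3)

for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)   # old gym API: 4 return values
        agent.store_transition(state, action, reward, next_state)
        agent.learn()                                     # recomputes TD errors and updates priorities
        state = next_state
    if episode % 10 == 0:
        agent.update_target_network()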

References

  1. https://zhuanlan.zhihu.com/p/128176891
  2. https://www.cnblogs.com/jiangxinyang/p/10112381.html