【FlappyBird小游戏】编写AI逻辑(二)——基于队列的经验重放池

本文介绍了三种实现深度强化学习中经验回放池(Replay Buffer)的方法,包括基于Numpy数组、Python数组和队列。经验回放池是Deep RL算法的重要组成部分,用于存储和采样过渡状态以进行训练。文中提供了简洁的代码示例,便于理解并应用于实际项目。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文隶属于一个完整小项目,建议读者按照顺序阅读。

本文仅仅展示最关键的代码部分,并不会列举所有代码细节,相信具备RL基础的同学理解起来没有困难。

全部的AI代码可以在【Python小游戏】用AI玩Python小游戏FlappyBird【源码】中找到开源地址。

如果本文对您有帮助,欢迎点赞支持!


文章目录

前言

第1种设计方式:基于Numpy数组

第2种设计方式:基于Python数组

第3种设计方式:基于队列


前言

书写经验重放池是Deep Rl算法的必备技术之一,常见的是基于数组的形式,本文列举3种常见的实现方式

本文不会详细介绍代码,因为太过简单,不理解的同学可以直接在评论区提问。


第1种设计方式:基于Numpy数组

class ReplayBuffer(object):
    def __init__(self, capacity,state_dims):
        self.capacity = capacity # 经验池容量大小
        self.data = np.zeros((capacity, state_dims* 2+2))  # 经验池存放的经验数据
        self.pointer = 0 # 当前指针
    def store_transition(self, s, a, r, s_):
        # 检查是否存在
        if not hasattr(self, 'pointer'):
            self.pointer = 0
        # 存储数据
        transition = np.hstack((s, [a,r], s_))  # 按行连接
        index = self.pointer % self.capacity  # 如果超过该容量则自动从头开始
        self.data[index, :] = transition
        self.pointer += 1
    def sample(self, batch_size):
        if self.capacity < self.pointer:
            batch_indexs = np.random.choice(self.capacity, size=batch_size)
        else:
            batch_indexs = np.random.choice(self.pointer, size=batch_size)
            #assert (self.pointer >= self.capacity, '经验回放池还没有被装满')
            #print('经验回放池还没有被装满就开始采样')
        return self.data[batch_indexs, :]  # 获取n个采样
    

第2种设计方式:基于Python数组

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

第3种设计方式:基于队列

本项目使用队列来进行设计,其代码更加简洁:

from collections import deque
import random
class ReplayBuffer(object):
    def __init__(self, capacity):
        self.memory_size = capacity # 容量大小
        self.num = 0 # 存放的经验数据数量
        self.data = deque() # 存放经验数据的队列

    def store_transition(self, state,action,reward,state_,terminal):
        self.data.append((state, action, reward, state_, terminal))# 添加数据
        if len(self.data) > self.memory_size:
            self.data.popleft()
            self.num -= 1
        self.num += 1

    def sample(self, batch_size):
        minibatch = random.sample(self.data, batch_size)
        return minibatch  # 获取n个采样

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔法攻城狮MRL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值