强化学习实践六 :给Agent添加记忆功能

本文介绍了如何在强化学习的Agent中添加记忆功能,通过抽象的Agent基类、状态转换、Episode和Experience的概念,实现个体可以从记忆中批量学习。此外,文章还详细阐述了PuckWorld环境,这是一个连续二维空间中的追逐目标物体的问题,具有连续的观测空间和离散的行动空间。后续将应用这些基础实现DQN算法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在《强化学习》第一部分的实践中,我们主要剖析了gym环境的建模思想,随后设计了一个针对一维离散状态空间的格子世界环境类,在此基础上实现了SARSA和SARSA(λ)算法。《强化学习》第二部分内容聚焦于解决大规模问题,这类问题下的环境的观测空间通常是多维的而且观测的通常是连续变量,或者行为不再是离散的简单行为,而是由可在一定区间内连续取值的变量构成,在解决这类大规模问题时必须要对价值函数(或策略函数)进行一定程度的近似表示。在对这些函数进行近似表示的时候,可以使用多种机器学习算法,其中最常用的是线性回归或(深度)神经网络。从本实践六开始,我将尝试实现一些用于解决大规模问题的强化学习算法。之所以说尝试主要是因为这类算法的调试和训练需要花费大量的精力,限于水平也可能得不到令人满意的结果,所以只能量力而为,有兴趣一起做的朋友可以合力编写代码。

在这部分的实践中,我将主要利用gym里提供的一些经典环境,比如CartPole、MountainCar等。同时,我也编写了公开课提到的一个PuckWorld的环境类,这个环境也很有意思,我也会用它来测试编写的代码。本次实践,我将把之前的两个Agent类抽象成一个基类(Agent),同时针对状态转换、Episode等进行建模以实现Agent可以具备一定的记忆功能以及从可以从记忆里批量学习,最后我将简单介绍一下PuckWorld环境。

 

抽象的Agent基类

为了体现继承和多态性,增加代码的复用性和可读性,我们先把Agent类做一个抽象,基类Agent除具备之前提到的一些执行策略、执行行为、学习等基本功能外,同时还具有记住一定数量的已经经历过的状态转换对象的功能,最后还应能从记忆中随机获取一定数量的状态转换对象以供批量学习的功能,为此,Agent类可以如下设计:

class Agent(object):
    '''Base Class of Agent
    '''
    def __init__(self, env: Env = None, 
                       trans_capacity = 0):
        # 保存一些Agent可以观测到的环境信息以及已经学到的经验
        self.env = env
        self.obs_space = env.observation_space if env is not None else None
        self.action_space = env.action_space if env is not None else None
        self.experience = Experience(capacity = trans_capacity)
        # 有一个变量记录agent当前的state相对来说还是比较方便的。要注意对该变量的维护、更新
        self.state = None   # current observation of an agent

    def performPolicy(self,policy_fun, s):
        if policy_fun is None:
            return self.action_space.sample()
        return policy_fun(s)

    def act(self, a0):
        s0 = self.state
        s1, r1, is_done, info = self.env.step(a0)
        # TODO add extra code here
        trans = Transition(s0, a0, r1, is_done, s1)
        total_reward = self.experience.push(trans)
        self.state = s1
        return s1, r1, is_done, info, total_reward

    def learning(self):
        '''need to be implemented by all subclasses
        '''
        raise NotImplementedError

    def sample(self, batch_size = 64):
        '''随机取样
        '''
        return self.experience.sample(batch_size)

    @property
    def total_trans(self):
        '''得到Experience里记录的总的状态转换数量
        '''
        return self.experience.total_trans

在上面的代码中,Agent类维护了从env对象得来的状态和行为空间对象,同时维护了一个state对象用于记录个体当前的状态(观测),此外多了一个experience对象。该对象表示的即是个体的记忆内容,它将记录个体在一定期限内所经历过的状态和行为等相关信息。让个体记住经历过的事件主要目的是使得个体可以从中随机获取一定数量的相互之间基本没有关联的状态转换信息,这些无关的状态转换信息将使得个体可以学到一个更好的价值函数的近似表示。在我的设计中,经历(Experience)将由一系列有序的Episode组

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值