经验回放的实现
- 使用
pd.DataFrame()
函数定义一个存储经验的表格 - 存储函数
store()
- 取出一个或多个样本的函数
sample()
第一步,在DQNplayer
类中,我们需要初始化一个存储结构
pd.DataFrame(index=range(capacity),
columns=['observation','action','reward',
'next_observation','done']
这个表格的行号是range(capacity)
,即表格中能够存储的经验数量;列号包括一个经验样本
(
S
i
,
A
i
,
R
i
,
S
i
′
)
(S_i, A_i, R_i, S'_i)
(Si,Ai,Ri,Si′) 以及回合是否结束的标志。
那么,整个初始化函数就为:
def __init__(self, capacity):
self.memory = pd.DataFrame(index=range(capacity),
columns=['observation','action','reward',
'next_observation','done')
self.i = 0 # 当前应该将新的经验存储到i行
self.count = 0 # 存储量
self.capacity = capacity # 存储容量
第二步,store()
函数的实现:
def store(self, *args):
self.memory.loc[self.i] = args
self.i = (self.i + 1)%self.capacity
self.count = min(self.count+1, self.capacity)
第三步,取出样本sample()
:
def sample(self, size):
# 选择抽取size条数据,indices存储数据行号组成的数列
indices = np.random.choice(self.count, size=size)
return (np.stack(self.memory.loc[indices, field]) for field in
self.memory.columns)
使用方法
# 初始化,存储容量为replayer_capacity
self.replayer = DQNplayer(replayer_capacity)
# 存储数据
self.replayer.store(observation, action, reward, next_observation, done)
# 提取样本
observations, actions, rewards, next_observations, dones = \
self.replayer.sample(size)