在github上分析了一个FlappyBird的DQN项目,把项目逻辑记录下来,加深DQN的理解。
PS: 跟踪到原始的github项目,可以找到几篇很经典的文献,有助于加深理解
图中矩形框标识数据内容,关键的几个模块
1. Environment
算法目标环境,此处就是flappy bird游戏,这个模块接受算法给出的action,返回reward。这个reward可以看作一个短期收益,根据游戏规则对action导致的结果给出收益。比如导致游戏失败,返回-1,否则返回0.1。
2. Memory
记忆库,像cache一样,记录历史记录,每个记录是四元组(当前状态,动作,收益,下一个状态),其中下一状态是指当前状态下,动作生效后进入的状态。
3. TargetNet
这个网络本身是不进行训练的,但是每T轮后,会从另一个模块QNet复制网络参数。这个网络接受当前状态,生成所有action的未来期望收益,游戏就是选择期望最大的那个action,施加到environment上。
4. QNet
和TargetNet网络结构一样,每间隔N轮后,从memory中随机抽取若干历史记录,进行训练,更新参数。T轮后,QNet会把参数复制TargetNet,因此TargetNet在T轮内网络参数是不变的,保证一段时间内算法输出是“稳定的”。 训练中,输入的是当前状态,QNet会预测在”current_state”下,所有action的期望收益,抽取”current_state”对应的”action”的收益,和”current_state”对应的”reward”计算loss。
PS:根据代码,从memory中随机抽取的历史数据,送入QNet做训练前,会把记录中”next_state”送入targetNet,利用ta