[Deep Learning] DQN Network Code Explained

Foreword

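The code was written by Morvan Zhou (莫烦). The original source is available on GitHub, and his tutorial video series can be found on video-sharing sites.
I have annotated every part of the code with comments explaining what it does, along with brief notes on how some of the functions and methods are used. Study it line by line or just skim it for the general idea, as you prefer.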

The Code

```python
"""
This part of code is the DQN brain, which is a brain of the agent.
All decisions are made in here.
Using Tensorflow to build the neural network.

View more on my tutorial page: https://morvanzhou.github.io/tutorials/

Using:
Tensorflow: 1.0
gym: 0.7.3
"""

import numpy as np
import pandas as pd
import tensorflow as tf

# Fix the random seeds so that runs are reproducible
np.random.seed(1)
tf.set_random_seed(1)


# Deep Q Network off-policy
class DeepQNetwork:
    def __init__(
            self,
            n_actions,
            n_features,
            learning_rate=0.01,
            reward_decay=0.9,
            e_greedy=0.9,
            replace_target_iter=300,
            memory_size=500,
            batch_size=32,
            e_greedy_increment=None,
            output_graph=False,
    ):
        self.n_actions = n_actions  # number of available actions
        self.n_features = n_features  # number of features describing one state (state dimension)
        self.lr = learning_rate  # learning rate
        self.gamma = reward_decay  # discount factor for future rewards
        self.epsilon_max = e_greedy  # (maximum) probability of taking the greedy action
        self.replace_target_iter = replace_target_iter  # learning steps between copying eval_net params into target_net
        self.memory_size = memory_size  # replay memory capacity (number of transitions)
        self.batch_size = batch_size  # number of transitions sampled per learning step
        self.epsilon_increment = e_greedy_increment  # how much epsilon grows per learning step (None = fixed)
        # if epsilon is meant to grow, start from 0 (always explore); otherwise use epsilon_max from the start
        self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max

        # total learning step
        self.learn_step_counter = 0  # counts learning steps; used to decide when to replace target params

        # initialize zero memory [s, a, r, s_]
        # each row holds one transition: n_features (s) + 1 (a) + 1 (r) + n_features (s_)
        self.memory = np.zeros((self.memory_size, n_features * 2 + 2))

        # consist of [target_net, evaluate_net]
        self._build_net()
        t_params = tf.get_collection('target_net_params')
        e_params = tf.get_collection('eval_net_params')
        # zip() pairs up the corresponding elements of its iterables and returns an iterator of tuples;
        # here it pairs each target_net variable with the matching eval_net variable,
        # so each tf.assign op copies one eval_net parameter into target_net
        self.replace_target_op = [
            tf.assign(t, e) for t, e in zip(t_params, e_params)
        ]

        self.sess = tf.Session()

        if output_graph:
            # view the graph with: $ tensorboard --logdir=E:/Code/logs
            # tf.train.SummaryWriter is deprecated; use tf.summary.FileWriter instead
            tf.summary.FileWriter("E:/Code/logs", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())
        self.cost_his = []  # history of training costs (losses)
```
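The `reward_decay` argument becomes `self.gamma`, the discount factor in the Q-learning target y = r + γ · max_a' Q_target(s_, a'). The learning method that uses it is not part of the excerpt above, so what follows is only a NumPy sketch under stated assumptions: `build_q_target` is a hypothetical helper, `q_eval` is assumed to hold the eval network's outputs for `s` and `q_next` the target network's outputs for `s_`. Only the entry of the action actually taken is overwritten, so the other actions contribute zero error.

```python
import numpy as np


def build_q_target(q_eval, q_next, actions, rewards, gamma):
    """Hypothetical helper: DQN training targets for one batch.

    q_eval : (batch, n_actions) Q(s, .) from the eval network (assumed)
    q_next : (batch, n_actions) Q(s_, .) from the target network (assumed)
    actions: (batch,) integer indices of the actions actually taken
    rewards: (batch,) immediate rewards r
    gamma  : discount factor (reward_decay in the class above)
    """
    q_target = q_eval.copy()  # untouched entries keep q_eval's values -> zero error
    batch_index = np.arange(q_eval.shape[0])
    # y = r + gamma * max_a' Q_target(s_, a'), written only at the taken action
    q_target[batch_index, actions] = rewards + gamma * np.max(q_next, axis=1)
    return q_target
```

For instance, with `q_next[0] = [3.0, 1.0]`, reward 1.0 and gamma 0.9, the taken action's target in row 0 is 1 + 0.9 × 3 = 3.7. This sketch does not treat terminal transitions specially; a common variant multiplies the bootstrap term by `(1 - done)`.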
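Note the convention in `__init__`: `e_greedy` is the probability of choosing the greedy action, not the probability of exploring. When `e_greedy_increment` is set, `epsilon` starts at 0 (pure exploration) and grows toward `epsilon_max`. The excerpt does not include the class's own action-selection or epsilon-update code, so here is a minimal NumPy sketch of that behaviour; `choose_action` and `update_epsilon` are hypothetical names, and capping with `min()` is an assumption about how the upper bound is enforced.

```python
import numpy as np


def choose_action(q_values, epsilon, rng):
    """Epsilon-greedy in this code's convention (hypothetical helper):
    with probability `epsilon` take the greedy (highest-Q) action,
    otherwise pick an action uniformly at random."""
    if rng.uniform() < epsilon:
        return int(np.argmax(q_values))
    return int(rng.integers(len(q_values)))


def update_epsilon(epsilon, epsilon_increment, epsilon_max):
    """Grow epsilon by a fixed step per learning step, capped at epsilon_max.
    With no increment, epsilon simply stays at epsilon_max (hypothetical helper)."""
    if epsilon_increment is None:
        return epsilon_max
    return min(epsilon + epsilon_increment, epsilon_max)


rng = np.random.default_rng(0)
eps = 0.0
for _ in range(5):
    eps = update_epsilon(eps, 0.25, 0.9)  # 0.25, 0.5, 0.75, then capped at 0.9
action = choose_action(np.array([0.1, 0.7, 0.2]), eps, rng)
```

Starting at 0 means early actions are random, which fills the replay memory with varied transitions before the agent starts trusting its Q estimates.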
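`self.memory` is an array of shape `(memory_size, n_features * 2 + 2)` in which each row stores one transition `[s, a, r, s_]`. The excerpt only allocates it, so the sketch below shows one common way such a buffer is filled and sampled, assuming a circular overwrite (`counter % memory_size`) once it is full and uniform random sampling of `batch_size` rows. `ReplayMemory`, `store` and `sample` are hypothetical names, not methods of the class above.

```python
import numpy as np


class ReplayMemory:
    """Hypothetical fixed-size replay buffer with the same row layout as
    self.memory: [s (n_features), a, r, s_ (n_features)]."""

    def __init__(self, memory_size, n_features):
        self.memory_size = memory_size
        self.n_features = n_features
        self.memory = np.zeros((memory_size, n_features * 2 + 2))
        self.counter = 0  # total number of transitions ever stored

    def store(self, s, a, r, s_):
        row = np.hstack((s, [a, r], s_))
        index = self.counter % self.memory_size  # overwrite the oldest row once full
        self.memory[index, :] = row
        self.counter += 1

    def sample(self, batch_size, rng):
        # sample only rows that have been written (call after at least one store)
        upper = min(self.counter, self.memory_size)
        idx = rng.choice(upper, size=batch_size)  # with replacement
        batch = self.memory[idx, :]
        n = self.n_features
        s = batch[:, :n]
        a = batch[:, n].astype(int)
        r = batch[:, n + 1]
        s_ = batch[:, -n:]
        return s, a, r, s_
```

Sampling random rows instead of the latest consecutive steps breaks the correlation between neighbouring transitions, which is the main reason DQN keeps a replay memory at all.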
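`self.replace_target_op` is a list of `tf.assign(t, e)` ops built with `zip()`, which pairs elements by position: `list(zip([1, 2], ['a', 'b']))` gives `[(1, 'a'), (2, 'b')]`. Because of this, both collections must list their variables in the same order, and `zip()` silently stops at the shorter one. Combined with `learn_step_counter` and `replace_target_iter`, this gives a periodic hard update of the target network. The NumPy sketch below mimics that with in-place array copies instead of TensorFlow variables; `replace_target_params` and `maybe_replace` are hypothetical names, and checking `learn_step_counter % replace_target_iter == 0` is an assumption about how the counter is used, since the learning method is not in this excerpt.

```python
import numpy as np


def replace_target_params(target_params, eval_params):
    """NumPy stand-in for running the list of tf.assign(t, e) ops:
    zip() pairs the i-th target array with the i-th eval array."""
    for t, e in zip(target_params, eval_params):
        t[...] = e  # copy values in place, like tf.assign(t, e)


def maybe_replace(learn_step_counter, replace_target_iter, target_params, eval_params):
    """Hard update (hypothetical helper): copy eval -> target once every
    replace_target_iter learning steps. Returns True if a copy happened."""
    if learn_step_counter % replace_target_iter == 0:
        replace_target_params(target_params, eval_params)
        return True
    return False


eval_p = [np.ones((2, 2)), np.full(3, 5.0)]
target_p = [np.zeros((2, 2)), np.zeros(3)]
maybe_replace(300, 300, target_p, eval_p)  # step 300: target now equals eval
```

Keeping the target network frozen between copies means the bootstrap term r + γ · max Q_target(s_, a') does not shift on every gradient step, which tends to stabilise training.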