强化学习与TensorFlow模型的训练部署
1. 强化学习训练设置
在强化学习训练中,我们在每个训练步骤会采样64个轨迹的批次,每个轨迹包含2个步骤(即2步构成1次完整转移,包括下一步的观测)。数据集会并行处理3个元素,并预取3个批次。
对于策略梯度等在线策略算法,每个经验应只采样一次,用于训练后就丢弃。此时仍可使用回放缓冲区,不过不使用数据集,而是在每次训练迭代时调用回放缓冲区的 gather_all() 方法获取包含所有已记录轨迹的张量,用于训练步骤,最后调用 clear() 方法清空回放缓冲区。
1.1 创建训练循环
为加速训练,我们将主要函数转换为TensorFlow函数,使用 tf_agents.utils.common.function() 来包装 tf.function() :
from tf_agents.utils.common import function
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)
创建一个运行主训练循环的函数:
def train_agent(n_iterations):
time_step = None
policy_state = agent.
超级会员免费看
订阅专栏 解锁全文
1931

被折叠的 条评论
为什么被折叠?



