TensorFlow: tf.gradients() usage and understanding the stop_gradients parameter

This article takes a close look at TensorFlow's tf.gradients function: how to call it to compute derivatives, and what the ys, xs, and stop_gradients arguments do. Worked examples show how the computed gradients change under different argument settings, to help you follow what happens during back-propagation.

tf.gradients()

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None
)

ys: a Tensor or a list of Tensors; the function(s) to be differentiated, i.e. the objective.
xs: a Tensor or a list of Tensors; the tensors to differentiate with respect to. (The result is d ys / d xs.)
stop_gradients: optional; a Tensor or a list of Tensors that gradients should not be back-propagated through (a bit abstract; see the examples below).

A few examples help make this concrete:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b])
with tf.Session() as sess:
    print(sess.run(g))
Result: [3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a])
with tf.Session() as sess:
    print(sess.run(g))
Result: [3.0, 1.0]

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[b])
with tf.Session() as sess:
    print(sess.run(g))
Result: [1.0, 1.0]
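The three snippets above use the TensorFlow 1.x graph and Session API, so they will not run as written under TensorFlow 2.x, where eager execution is on by default and tf.Session is gone. A minimal sketch of how to reproduce the last example there, assuming the tf.compat.v1 API is acceptable:

import tensorflow as tf

# Assumption: TensorFlow 2.x. tf.gradients and Session only work in graph
# mode, so switch eager execution off first.
tf.compat.v1.disable_eager_execution()

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[b])
with tf.compat.v1.Session() as sess:
    print(sess.run(g))  # expected: [1.0, 1.0]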

From these examples you can see that the first argument, ys, is the function to be differentiated, and the second argument, xs, lists the tensors whose gradients you want from back-propagation. The third argument, stop_gradients, blocks back-propagation through the listed tensors: if b is listed, then a and b in a + b are treated as independent and each gradient is 1; otherwise back-propagation also flows from b back to a, so a + b is effectively 3a (because b = 2a in this example) and the gradient with respect to a is 3.
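Listing b in stop_gradients behaves as if b had been wrapped in tf.stop_gradient() when the graph was built; the difference is that stop_gradients lets you cut the path after the graph already exists. A minimal sketch of the construction-time version, in the same TF 1.x style as the examples above:

a = tf.constant(0.)
b = tf.stop_gradient(2 * a)   # cut the path from a + b back through b to a
g = tf.gradients(a + b, [a, b])
with tf.Session() as sess:
    print(sess.run(g))        # expected: [1.0, 1.0], same as stop_gradients=[b]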


If anything here is unclear or wrong, feel free to leave me a comment. If this helped you, please give it a like. Thanks for reading!

