基于tensorflow如何对变化中的学习率实现与之相关的梯度裁剪-VDSR

# NOTE: learning rate decay
        global_step = tf.Variable(0, trainable=False)
        #learning_rate = tf.train.exponential_decay(config.learning_rate, global_step * config.batch_size, len(input_)*100, 0.1, staircase=True)
        learning_rate = tf.train.exponential_decay(config.learning_rate, global_step , 10, 0.1, staircase=True)#自己设就行
        # NOTE: Clip gradient
        #代码修改 在这里我们输入的learning_rate为自适应变化的 对应论文中梯度裁剪的大小也不同
        #opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        #lr  = opt._lr

        grad_and_value = opt.compute_gradients(self.loss)

        clip = tf.Variable(config.clip_grad, name='clip')
        clip = config.clip_grad/learning_rate
        capped_gvs = [(tf.clip_by_value(grad, -(clip), clip), var) for grad, var in grad_and_value]
        #小于-clip的等于-clip 大于的等于clip
        self.train_op = opt.apply_gradients(capped_gvs, global_step=global_step)#将计算出的梯度应用到变量上,是函数minimize()的第二部分
        #self.train_op = tf.train.AdamOptimizer(learning_rate=config.learning_rate).minimize(self.loss)

当时还发现 自己因为是输出对象和输入对象重名出现了问题。对应输出代码如下:

  # Train
        if not config.is_train:
            print("Now Start Training...")
            for ep in range(config.epoch):
                # Run by batch images
                batch_idxs = len(input_) // config.batch_size
                #for idx in range(0, batch_idxs):
                for idx in range(0, 4):
                    #print(lr)
                    batch_images = input_[idx * config.batch_size : (idx + 1) * config.batch_size]
                    batch_labels = label_[idx * config.batch_size : (idx + 1) * config.batch_size]
                    counter += 1
                    #print(self.sess.run(lr))
                    _, err,clipduan = self.sess.run([self.train_op, self.loss,clip], feed_dict={self.images: batch_images, self.labels: batch_labels})
                    ####注意这个地方 前面的名字和后面对应的名字不能一样 否则会触发报错
                    if counter % 3 == 0:
                        print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" % ((ep+1), counter, time.time()-time_, err ))
                        print(clipduan)
                    #if counter % 500 == 0:
                        self.save(config.checkpoint_dir, counter)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值