tf.Saver() 的使用心得

本文分享了在TensorFlow中使用tf.Saver()进行模型保存和恢复的心得,特别是在fine-tune预训练模型和分阶段训练时遇到的问题及解决方案。问题包括:1) fine-tune时如何处理Adam优化器的变量;2) 分阶段训练时如何处理变量的trainable属性变化。文中提供了两种解决方法,并强调了变量初始化和restore的顺序与选择。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自己训练神经网络时,可能需要:1. fine-tune 预训练好的模型; 2. 分阶段训练。 这时tf.Saver()的使用必不可少。以下是自己使用这个函数的一些心得。

示例代码:

import tensorflow as tf
import numpy as np

def add_layer(x, kin, kout, name, trainable=True):
    W = tf.Variable(tf.random_normal([kin, kout])*(2/kin)**0.5, name='W'+name, trainable=trainable)
    b = tf.Variable(tf.ones([kout])*0.1, name='b'+name, trainable=trainable)
    return tf.nn.relu(tf.matmul(x, W) + b)

x = np.linspace(0, 1)[:, np.newaxis].astype(np.float32)
y = x**2 + 0.1

net = add_layer(x, 1, 10, '1')
net = add_layer(net, 10, 100, '2', trainable=False)
net = add_layer(net, 100, 1, '3')

loss = tf.reduce_mean(tf.square(net - y))
train = tf.train.AdamOptimizer(0.001).minimize(loss)

sess = tf.Session()    
#v_list = [v for v in tf.global_variables() if 'Adam' in v.name or 'beta' in v.name]
#sess.run(tf.initialize_variables(v_list))
sess.run(tf.global_variables_initializer())

v_list = [v for v in tf.global_variables() if 'Adam' not in v.name and 'beta' not i
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值