背景:
训练好一个网络后,欲修改某个网络层的权值size或者结构,重新初始化该网络层,保留其他网络层参数,继续训练。
方法:
此处需要用到slim包中 slim.get_variables_to_restore 函数,由于slim包属于models模块,已在tensorflow1.0后被取消。因此需要重新下载:下载路径:https://github.com/tensorflow/models
步骤:
1.下载slim:
cd ../models/research/slim
git clone https://github.com/tensorflow/models.git
2.引用slim:
sys.path.append("./models/research/slim")
import tensorflow.contrib.slim as slim
3.restore网络模型:
ckpt = tf.train.get_checkpoint_state(CHECK_POINT_PATH) #判断模型是否存在。
if ckpt and ckpt.model_checkpoint_path:
exclude = ['input', 'layer_1'] #忽略模型中这两个层
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
all_parameters_saver = tf.train.Saver(variables_to_restore)
all_parameters_saver.restore(sess, ckpt_path)
其中,CHECK