Loading a pretrained model
https://cloud.tencent.com/developer/article/1197031
Loading part of a pretrained model
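# Collect every variable in the graph (tf.contrib exists only in TensorFlow 1.x).
variables = tf.contrib.framework.get_variables_to_restore()
# Keep only the variables whose top-level scope is 'easyflow'.
variables_to_restore = [v for v in variables if v.name.split('/')[0] == 'easyflow']
# A Saver built from this list restores only those variables from the checkpoint.
saver_res = tf.train.Saver(variables_to_restore)
saver_res.restore(sess, pre_train_model)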
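The list comprehension above picks variables by their top-level name scope. As a rough illustration that runs without TensorFlow, the sketch below applies the same filter to stand-in objects. `FakeVar`, `select_by_top_scope`, and every variable name other than `easyflow/c1/weights:0` are hypothetical and exist only for this example.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only the .name attribute matters here.
FakeVar = namedtuple("FakeVar", ["name"])


def select_by_top_scope(variables, scope):
    """Return the variables whose first '/'-separated name component equals scope."""
    return [v for v in variables if v.name.split("/")[0] == scope]


# Illustrative names; a real graph would supply its own.
all_vars = [
    FakeVar("easyflow/c1/weights:0"),
    FakeVar("easyflow/c1/biases:0"),
    FakeVar("netflow/fc/weights:0"),
    FakeVar("global_step:0"),
]

to_restore = select_by_top_scope(all_vars, "easyflow")
print([v.name for v in to_restore])
```

Only the two `easyflow` entries are selected, so a Saver built from such a list would not touch the other variables.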
Freezing part of the network's parameters
variables = slim.get_variables_to_restore()
# Variables under 'easyflow' are restored from the checkpoint and stay frozen.
variables_to_restore = [v for v in variables if v.name.split('/')[0] == 'easyflow']
# Only variables under this scope are handed to the optimizer.
train_scope = 'netflow'
print(variables_to_restore)
# Define the optimizer
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
# scope must be a string (matched as a regex against variable names), not a list.
loss_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=train_scope)
with tf.control_dependencies(update_ops):
    Training_step2 = tf.compat.v1.train.AdamOptimizer(lr_ori).minimize(OptimizeLoss_2, var_list=loss_vars)
saver_res = tf.compat.v1.train.Saver(variables_to_restore)
saver_res.restore(sess, pre_train_model)
# Fetch a restored weight tensor by name, e.g. to inspect its value.
varvar = sess.graph.get_tensor_by_name('easyflow/c1/weights:0')
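One detail worth checking when freezing by scope: `get_collection` filters by `scope` using `re.match`, so a plain scope string behaves as a prefix match, which is looser than the exact first-component test used for the restore list. The sketch below imitates both rules on plain strings to show where they differ. It is a model of the matching behavior, not TensorFlow code, and the helper names and the `netflow_aux` scope are made up for illustration.

```python
import re


def match_by_regex_scope(names, scope):
    """Mimic get_collection(scope=...): keep names where re.match succeeds."""
    pattern = re.compile(scope)
    return [n for n in names if pattern.match(n)]


def match_by_top_scope(names, scope):
    """Keep names whose first '/'-separated component equals scope exactly."""
    return [n for n in names if n.split("/")[0] == scope]


# Hypothetical variable names for illustration.
names = [
    "easyflow/c1/weights:0",
    "netflow/fc/weights:0",
    "netflow_aux/fc/weights:0",
]

print(match_by_regex_scope(names, "netflow"))   # prefix match: also picks up netflow_aux
print(match_by_top_scope(names, "netflow"))     # exact match on the top-level scope
print(match_by_regex_scope(names, "netflow/"))  # trailing slash excludes netflow_aux here
```

If a graph has sibling scopes that share a prefix, passing the scope with a trailing slash is one way to keep the optimizer from updating the wrong variables.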
Tensor operations:
printing a tensor,
tf.slice(input_, begin, size),
tf.split(value, num_or_size_splits, axis=0, num=None),
tf.concat(values, axis),
tf.stack(values, axis=0),
tf.unstack(value, num=None, axis=0)
http://chenjingjiu.cn/index.php/2019/07/05/tensorflow-matrix-op/
tf.transpose(): permutes the dimensions of a tensor
https://blog.youkuaiyun.com/qq_40994943/article/details/85270159
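As a small illustration of what the `perm` argument of tf.transpose does, the sketch below models the rule that output dimension i takes the size of input dimension perm[i], with the default reversing all dimensions. It uses plain Python lists instead of tensors, and the helper names are hypothetical.

```python
def transpose_shape(shape, perm=None):
    """Output shape of a transpose: dimension i comes from input dimension perm[i]."""
    if perm is None:
        # Default behavior: reverse the order of the dimensions.
        perm = list(range(len(shape)))[::-1]
    return [shape[p] for p in perm]


def transpose_2d(matrix):
    """Transpose a rectangular list-of-lists matrix (the 2-D case, perm=[1, 0])."""
    return [list(row) for row in zip(*matrix)]


print(transpose_shape([2, 3, 4]))                   # [4, 3, 2]
print(transpose_shape([2, 3, 4], perm=[0, 2, 1]))   # [2, 4, 3]
print(transpose_2d([[1, 2, 3], [4, 5, 6]]))         # [[1, 4], [2, 5], [3, 6]]
```

With `perm=[0, 2, 1]` the first axis stays in place and only the last two are swapped, which is the usual way to transpose each matrix in a batch.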
tf.split(): splits a tensor into several sub-tensors
https://blog.youkuaiyun.com/qq_31150463/article/details/84137883
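The second argument of tf.split can be either an integer (number of equal pieces) or a list of piece sizes. The sketch below models the resulting shapes in plain Python under those two readings. It is an illustration rather than TensorFlow code, `split_shapes` is a hypothetical helper, and only non-negative `axis` values are handled.

```python
def split_shapes(shape, num_or_size_splits, axis=0):
    """Shapes of the pieces produced by splitting a tensor of `shape` along `axis`."""
    dim = shape[axis]
    if isinstance(num_or_size_splits, int):
        # Integer form: the axis must divide evenly into that many pieces.
        if dim % num_or_size_splits != 0:
            raise ValueError("axis length is not divisible by the number of splits")
        sizes = [dim // num_or_size_splits] * num_or_size_splits
    else:
        # List form: the sizes must add up to the axis length.
        sizes = list(num_or_size_splits)
        if sum(sizes) != dim:
            raise ValueError("split sizes do not sum to the axis length")
    pieces = []
    for s in sizes:
        piece = list(shape)
        piece[axis] = s
        pieces.append(piece)
    return pieces


print(split_shapes([5, 30], 3, axis=1))            # [[5, 10], [5, 10], [5, 10]]
print(split_shapes([5, 30], [4, 15, 11], axis=1))  # [[5, 4], [5, 15], [5, 11]]
```

Every piece keeps the rank of the input; only the length along `axis` changes.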