加载预训练模型需要具备两个条件:1.框架结构(知道每一层的名字),2. 预训练好的模型文件.ckpt
加载预训练模型代码如下:
import tensorflow
as tf
import numpy
as np
weights_1 = tf.Variable(tf.zeros([3,4]))
# weights_2 = tf.Variable(tf.zeros([4,3]))
sess = tf.InteractiveSession()
saver = tf.train.Saver()
saver.restore(sess, '/tmp/checkpoint/model.ckpt')
o_test = np.array([[4.0,
3.0,
2.0]], dtype=
'float32')
label = tf.matmul(o_test, weights_1)
# label = tf.matmul(label, weights_2)
print sess.run(label)
以上为加载模型代码,可以写全变量名,也可只写一部分。可根据输出来定。
其中,model.ckpt训练模型代码如下:
import tensorflow
as tf
import numpy
as np
i_data = np.array([[5.0,
3.0,
2.0]], dtype
= 'float32')
i_label= np.array([[15.0,
10.0,
22.0]], dtype
= 'float32')
weights_1 = tf.Variable(tf.zeros([3,4]))
out_1 = tf.matmul(i_data, weights_1)
weights_2 = tf.Variable(tf.zeros([4,3]))
out = tf.matmul(out_1, weights_2)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
loss = tf.reduce_mean(tf.square(out
- i_label))
training = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
sess = tf.Session()
sess.run(init_op)
for i
in range(20000):
sess.run(training)
save_path = saver.save(sess,
'/tmp/checkpoint/model.ckpt')