The network architecture is as follows:
Parameters: a1, a2, b1, b2
Network output: y = a1*a2*x + b1 + b2
Target function: y = x
1. Partial restore of network parameters
Use case: the network architecture has been modified, but some of the existing parameters should be reused.
How to set it up: pass a var_list argument to tf.train.Saver, and only the variables in var_list will be saved/restored.
Usage:
(1) Save: a1, a2, b1, b2 are set to 10, 20, 30, 40 respectively
import tensorflow as tf

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([10], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([30], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([20], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([40], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# collect the variables of each scope (shown here for symmetry with the restore script below)
var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')
# with no var_list, tf.train.Saver saves all global variables
saver = tf.train.Saver()
saver.save(sess, "./10203040")
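Since tf.train.Saver() above was constructed without a var_list, the checkpoint at "./10203040" contains all four variables. If the checkpoint itself should only contain a subset, pass the list at construction time; a minimal sketch (the "./ab1_only" path is a hypothetical example name):

saver_ab1 = tf.train.Saver(var_list_ab1)
saver_ab1.save(sess, "./ab1_only")

A checkpoint written this way holds only AB1/A1 and AB1/B1, so a2 and b2 can never be restored from it.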
(2) Restore:
import tensorflow as tf

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')
# restore only the AB1 variables; AB2 keeps its fresh initial values
saver = tf.train.Saver(var_list_ab1)
saver.restore(sess, "./10203040")
print(sess.run([a1, a2, b1, b2]))  # expected: a1=10, b1=30 (restored); a2=2, b2=4 (fresh init)
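var_list does not have to be a list: tf.train.Saver also accepts a dict mapping checkpoint variable names to variables in the current graph, which covers the "architecture modified" use case when the new graph uses different names. A minimal sketch, assuming hypothetical new variables new_a1 and new_b1 that should receive the values stored under AB1/A1 and AB1/B1:

# map checkpoint names to the (hypothetical) renamed variables in the new graph
saver = tf.train.Saver({"AB1/A1": new_a1, "AB1/B1": new_b1})
saver.restore(sess, "./10203040")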
2. Training only part of the parameters, e.g. train only a1 and b1 while a2 and b2 stay fixed
How to set it up: pass the list of variables to be trained to the optimizer via its var_list argument:
Method 1 (obtain var_list from the variable_scope):
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)
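Besides passing var_list to minimize, a variable can also be excluded from training for the whole graph by creating it with trainable=False, which keeps it out of the TRAINABLE_VARIABLES collection that minimize uses by default. A minimal sketch of freezing a2 and b2 this way:

with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2", trainable=False)
    b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2", trainable=False)
# no var_list needed: minimize only updates trainable variables
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss)

The difference: var_list restricts a single optimizer, while trainable=False freezes the variable for every optimizer in the graph.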
Method 2 (a more manual, clumsier way):
import tensorflow as tf
import random

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1 + b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1*a2*x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list = [a1, b1]  # build var_list by hand: only a1 and b1 will be updated
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)
while True:
    input = [random.randint(0, 100) * 0.01]  # draw inputs in [0, 1]; without this down-scaling the network fails to converge
    label = [input[0]]
    _, a1v, a2v, b1v, b2v, lossv = sess.run([train, a1, a2, b1, b2, loss],
                                            feed_dict={x: input, _y: label})
    if lossv < 1e-10:
        break
print("train data=%s %s" % (input, label))
print("a=%s %s\n b=%s %s\n loss=%s" % (a1v, a2v, b1v, b2v, lossv))