问题来源:
0)模型训练好了(tensorflow训练),想进行离线的测试
1)我想将多个模型的计算封装成各自的函数接口;
2)同时为了提升计算速度,将模型的session对象设置成了全局变量(方便快捷的处理每次请求);
3)导致一个graph中出现了两个模型;
问题解释:tensorflow默认为进程设置一个默认的graph, 一个graph只能存在一个session,结果就是两个模型在一起出现冲突。
解决办法:为每个模型设置一个graph,但是每个graph必须要事先加载网络模型结构
具体的函数封装:
XX_Model_caculate.py (为例)
#网络结构参数:
n_input = 3
n_steps = 180
n_hidden = 512
...
#创建一个graph变量:
g1 = tf.Graph()
#创建一个session变量:需要指定Graph
sess1=th.Session(graph=g1)
#定义一个model_path
Model1_path="XX/XX" #即自己保存模型的目录(不用具体到文件)
#加载网络结构和参数:
with sess1.as_default():
with g1.as_default():
#创建网络结构( 即 把训练时的网络结构定义拿过来 )
a)定义输入
b)定义网络结构
c) 定义预测的输出
然后:
ft.global_variables_initializer().run()
model_saver = tf.train.Saver(tf.global_variables())
model_cpt = tf.train.get_checkpoint_statle(Model1_path)
model_saver.restore(sess1,model_cpt.model_checkpoint_path)
# 定义成函数接口
def Model1_caculate(data=None):
x_data=np.asarray(data)
x_data=np.reshape(data,[x,x,x]) #表示reshape成自己模型的输入情况
temp_y = np.ones([1,n_classed]) #即伪造一个y用作sess1.run()的输入
with sess1.as_default(): #因为sess1设置成全局变量,所有可以直接使用,也就节省了加载时间
with sess1.graph.as_default():
result=sess1.run(预测变量名,feed_dict={x=x_data, y=temp_y })
return result
调用方法:(在另一个文件中调用)
import XX_Model_caculate if __name__ == '__main__':
test_data #先获取到源数据
get_result=XX_Model_caculate.Model1_caculate(data=test_data) #计算得到结果
问题得到解决。
另一种解决问题的办法:
tensorflow是静态图;(最新的tensorflow是动态图,不知道如何)
如果是pytorch,就没那么麻烦了,模型随时定义,网络随机加载,参数随时使用。(前提是得再pytorch上训练好模型)