错误场景
为了每隔固定时间训练一次模型, luffy在线程函数中设置timer再次调用线程函数, 简化版代码如下
def _thread_func(interval=10):
model = train_model()
timer = threading.Timer(interval, _thread_func, args=(interval)) #设置定时器间隔interval后再次调用_thread_func, 无限循环
timer.start()
def train():
x_train, y_train = get_dataset() #获取训练集
model = build_model() #构建模型的输入输出和中间层, 不详细展开
model.compile("adam", loss="mse")
model.fit(x_train, y_train, batch_size=32, epochs=10)
return model
错误内容
TypeError: Can not interpret feed_dict key as Tensor: Tensor Tensor("func:0", shape=(?,?), dtype=int32) is not an element of this graph.
原因分析
查找资料后发现, keras是基于tensorflow的(luffy用的backend是tensorflow), 而tensorflow有两个关键的对象, 计算图(Graph)和会话(Session). 以luffy的智商进行了理解就是: Graph和Session两者通常是一一对应的.
放到luffy碰到的问题中:
- _thread_func1第一次执行时, 变量是定义在默认的计算图上的, 我们称之为default_graph,第一次训练, 也是运行在默认会话上的, 我们称之为default_session.
- 隔了一段时间后, _thread_func1调用了自己, 新开辟了一个线程_thread_func2, 此时在_thread_func2也调用了train()函数, 但是此时变量是定义在新的graph上的, 我们称之为new_graph, 但是进程并没有切换新的session, 用的还是default_session, 所以就出错啦.
代码修正
非常简单, 加两行代码即可, 保证计算图和会话一致.
def train():
with tf.Graph().as_default(): #第一行增加代码
x_train, y_train = get_dataset() #获取训练集
model = build_model() #构建模型的输入输出和中间层, 不详细展开
model.compile("adam", loss="mse")
with tf.Session() as sess: #第二行增加代码
model.fit(x_train, y_train, batch_size=32, epochs=10)
return model