多线程使用keras训练模型错误-"is not an element of this graph"

本文解决了一个在使用TensorFlow进行多线程模型训练时遇到的TypeError错误,详细分析了错误产生的原因,即在不同线程中变量定义的计算图与会话不匹配,并提供了代码修正方案,确保每个线程中的计算图和会话一致。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

错误场景

为了每隔固定时间训练一次模型, luffy在线程函数中设置timer再次调用线程函数, 简化版代码如下

def _thread_func(interval=10):
	model = train_model()				
	timer = threading.Timer(interval, _thread_func, args=(interval))		#设置定时器间隔interval后再次调用_thread_func, 无限循环	
	timer.start()
			
def train():
	x_train, y_train = get_dataset()		#获取训练集
	
	model = build_model()					#构建模型的输入输出和中间层, 不详细展开
	model.compile("adam", loss="mse")
	model.fit(x_train, y_train, batch_size=32, epochs=10)
	
	return model

错误内容

TypeError: Can not interpret feed_dict key as Tensor: Tensor Tensor("func:0", shape=(?,?), dtype=int32) is not an element of this graph.

原因分析

 查找资料后发现, keras是基于tensorflow的(luffy用的backend是tensorflow), 而tensorflow有两个关键的对象, 计算图(Graph)和会话(Session). 以luffy的智商进行了理解就是: Graph和Session两者通常是一一对应的.
 放到luffy碰到的问题中:

  • _thread_func1第一次执行时, 变量是定义在默认的计算图上的, 我们称之为default_graph,第一次训练, 也是运行在默认会话上的, 我们称之为default_session.
  • 隔了一段时间后, _thread_func1调用了自己, 新开辟了一个线程_thread_func2, 此时在_thread_func2也调用了train()函数, 但是此时变量是定义在新的graph上的, 我们称之为new_graph, 但是进程并没有切换新的session, 用的还是default_session, 所以就出错啦.

代码修正

非常简单, 加两行代码即可, 保证计算图和会话一致.

def train():
	with tf.Graph().as_default():				#第一行增加代码
		x_train, y_train = get_dataset()		#获取训练集
		
		model = build_model()						#构建模型的输入输出和中间层, 不详细展开
		model.compile("adam", loss="mse")
		with tf.Session() as sess:				#第二行增加代码
			model.fit(x_train, y_train, batch_size=32, epochs=10)
		
		return model
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值