用 tensorflow 1 时,想要多次重复实验取平均,在两次实验之间需要清一次计算图,否则会报错说 xx 变量重复定义。代码形式:
# import tensorflow
class MyModel:
def __init__(self):
# build model
def train(self):
with tf.Session() as sess:
# training
# 多 runs 取平均
for i_run in range(n_run):
model = MyModel()
model.train()
# 两 runs 之间清一次计算图
tf.reset_default_graph()
print("DONE")
本文探讨了如何在使用TensorFlow 1.x进行模型训练时,避免因重复定义变量导致的错误,并介绍如何在每次实验间清空计算图以实现多轮实验平均。通过实例演示了如何在`MyModel`类中正确设置和清理计算图,确保代码的正确执行。
831

被折叠的 条评论
为什么被折叠?



