上一篇文章简单讲了怎么断点续训,但是有时候吧,想要看看训练过程中那些参数都长啥样,从开始到结束各个参数如何变化,在网络训练过程中只有accuracy和loss可以看到,这时候需要使用参数提取的方法。
在上一篇基础上,加入一些新内容
import tensorflow as tf
import os
from tensorflow import keras
import numpy as np
# 设置打印格式,threshold=超过多少不打印,np.inf无限大
np.set_printoptions(threshold=np.inf)
mnist = keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train,x_test = x_train/255,x_test/255
model = keras.models.Sequential([
keras.layers.Flatten(),
keras.layers.Dense(128,activation='relu'),
keras.layers.Dense(10,activation='softmax')
])
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = './临时文件/mnist.ckpt'
# 生成ckpt的同时会生成索引文件
if os.path.exists(checkpoint_save_path+'.index'):
print('-----------load the model----------------')
model.load_weights(checkpoint_save_path)
cp_call

该博客介绍了如何在TensorFlow中实现训练过程中的参数提取和保存。通过使用ModelCheckpoint回调函数,不仅可以进行断点续训,还能观察并记录模型参数的变化。代码示例展示了加载MNIST数据集,构建模型,训练并保存最佳权重,最后将模型的训练变量及它们的值保存到文本文件中,便于后续分析和检查。
最低0.47元/天 解锁文章
1103

被折叠的 条评论
为什么被折叠?



