来源:北京大学Tensorflow2.0(B站搜)
直接加载保存的模型,见上一篇博客:TF2.0断点读序(2)
from PIL import Image
import numpy as np
import tensorflow as tf
model_save_path = 'E:/path/tf2.0_data/tf_board/save_one/mnist.ckpt'
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
model.load_weights(model_save_path)
preNum = int(input("input the number of test pictures:"))
for i in range(preNum):
image_path = input("the path of test picture:")
img = Image.open(image_path)
img = img.resize((28, 28), Image.ANTIALIAS)
img_arr = np.array(img.convert('L'))
for i in range(28):
for j in range(28):
if img_arr[i][j] < 200:
img_arr[i][j] = 255
else:
img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
print('\n')
tf.print(pred)

输入要读取的数量及数字图片路径就可以得到预测结果了
本文介绍如何在TensorFlow2.0中加载预训练模型,并使用该模型进行手写数字图片的识别。通过输入图片路径,模型能够预测并输出数字识别结果。涉及的内容包括模型结构定义、权重加载、图片预处理和预测流程。
3783

被折叠的 条评论
为什么被折叠?



