tensorflow本地导入图片进行手写数字识别
from PIL import Image #导入图片模块 import numpy as np #导入科学计算库 import tensorflow as tf #导入机器学习框架 model_save_path = './checkpoint/mnist.ckpt' #训练好的模型路径 model = tf.keras.models.Sequential([ #实现前向传播 tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation = "relu"), tf.keras.layers.Dense(10, activation = "softmax") ]) model.load_weights(model_save_path) #加载参数 preNum = int(input("输入要测试多少张图片:")) for i in range(preNum): #读入待识别的图片 image_path = input("测试图片的数字是:") img = Image.open(image_path) img = img.resize((28, 28), Image.ANTIALIAS) #把输入图片resieze成28*28尺寸的图片,转换为灰度图 img_arr = np.array(img.convert('L')) # img_arr = 255 - img_arr #颜色取反 for i in range(28): #让图片变为只有黑色和白色的高对比图 for j in range(28): if img_arr[i][j] < 200: img_arr[i][j] = 255 #像素点小于200的全部变为255,纯白色 else: img_arr[i][j] = 0 #像素点大于200的全部变为0,纯黑色 img_arr = img_arr / 255.0 #归一化处理 x_predict = img_arr[tf.newaxis, ...] #添加一个维度,变为三个维度 result = model.predict(x_predict) #预测结果 pred = tf.argmax(result, axis = 1) print("\n") tf.print(pred)
图片存放在根目录下 :
结果为:
E:\Anaconda3\envs\TF2\python.exe C:/Users/Administrator/PycharmProjects/untitled8/导入图片手写识别.py
输入要测试多少张图片:3
测试图片的数字是:4.png[4]
测试图片的数字是:7.png
[7]
测试图片的数字是:3.png
[3]
Process finished with exit code 0