CNN训练MNIST数据集tenflow2(下)

这篇博客介绍了使用TensorFlow2训练MNIST数据集的CNN模型,并探讨了模型量化的概念。作者展示了如何加载预训练模型,执行推理,并通过不同位宽的量化模型进行评估,得出精度结果。尽管量化模型的精度接近原始模型,但作者指出数据流量化是低精度计算的关键,并计划进一步研究低比特量化和数据流量化技术。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CNN训练MNIST数据集tenflow2(上)
CNN训练MNIST数据集tenflow2(中)

接下来利用tflite做量化,其实中间的数据流到底有没量化,我也不清楚。。。

一、数据准备

# 1.数据准备
import tensorflow as tf
import numpy as np

mnist = tf.keras.datasets.mnist
img_rows,img_cols = 28,28
(x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()
x_train = x_train_.reshape(x_train_.shape[0],img_rows,img_cols,1)
x_test = x_test_.reshape(x_test_.shape[0],img_rows,img_cols,1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 256
x_test = x_test / 256
y_train_onehot = tf.keras.utils.to_categorical(y_train_)
y_test_onehot = tf.keras.utils.to_categorical(y_test_)

二、初始模型导入

model = tf.keras.Sequential()
model = tf.keras.models.load_model('models/mnist_tf2_fw.h5')
score = model.evaluate(x_test, y_test_onehot, verbose=0)
print('Test accuracy:', "{:.5f}".format(score[1]))

三、量化模型导入

interpreter= tf.lite.Interpreter(model_path="models/tflite_tf2_8.tflite")
#interpreter= tf.lite.Interpreter(model_path="models/tflite_tf2_dy.tflite")
#interpreter= tf.lite.Interpreter(model_path="models/tflite_tf2_16.tflite")
#interpreter= tf.lite.Interpreter(model_path="models/tflite_tf2_32.tflite")

interpreter.allocate_tensors()

input_details=interpreter.get_input_details()
output_details=interpreter.get_output_details()

#print(input_details)
#print(output_details)
y_lebal=y_test_onehot.argmax(1)

# 量化模型推理
pre = []
# tflite针对安卓端运用,每次只能推理一个数据
for i in range (len(x_test)):
    x_test1 = x_test[i].reshape(1,28,28,1).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'],x_test1)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    pre.append(output_data.argmax())

temp = (pre==y_lebal)
acc=sum(temp)/10000
print(acc)
32bit16bit8bit默认
98.50098.50098.44098.440

这结果似乎和权重量化的差不多。数据流量化确实不太容易,需要更加底层、细粒度的操作。只有rram的训练有可能是4bit和8bit的混合,所以低精度数据流的训练还是挺重要的。

后期计划,做低比特、数据流量化的调用,看看别人的工作,为rram高性能计算打下理论基础。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值