项目地址:https://github.com/LionJackson/imageClassification
Flutter项目地址:https://github.com/LionJackson/flutter_image
一. 模型转换
编写 TFLite 模型转换工具类(TFLiteUtil):
import os
import PIL
import tensorflow as tf
import keras
import numpy as np
from PIL.Image import Image
from matplotlib import pyplot as plt
from utils.dataset_loader import DatasetLoader
from utils.utils import Utils
"""
tflite模型工具类
"""
class TFLiteUtil:
def __init__(self, saved_model_dir, path_url):
    """Remember the model location and the dataset location.

    Args:
        saved_model_dir: base path of the trained model. It is also used as
            the prefix of the generated ``.label`` and ``.tflite`` files and
            of the ``.keras`` file that is loaded.
        path_url: dataset root directory; class folders are read from its
            ``train`` sub-directory.
    """
    # NOTE: stored as ``save_model_dir`` (not ``saved_model_dir``) because
    # the other methods read it under that attribute name.
    self.save_model_dir, self.path_url = saved_model_dir, path_url
def get_folder_names(self):
    """Write the class-label file for the trained model and return the labels.

    Every immediate sub-directory of ``<path_url>/train`` is one class. The
    class names are sorted and written one per line to
    ``<save_model_dir>.label`` (UTF-8), so line *i* names model output *i*.

    Returns:
        list[str]: the class names, in the order they were written.

    Raises:
        FileNotFoundError: if ``<path_url>/train`` does not exist (instead of
            silently writing an empty label file).
    """
    train_dir = os.path.join(self.path_url, 'train')
    # Only the top level holds class folders: os.walk would also report
    # nested sub-directories as extra (bogus) labels. Sorting reproduces the
    # alphanumeric class order Keras uses when it infers labels from
    # directory names; os.walk's order is filesystem-dependent.
    # NOTE(review): assumes the training data was loaded with inferred,
    # directory-based labels -- confirm against utils.dataset_loader.
    with os.scandir(train_dir) as entries:
        folder_names = sorted(entry.name for entry in entries if entry.is_dir())
    # Encode explicitly as UTF-8: the platform default (e.g. GBK on a
    # Chinese-locale Windows machine) would make non-ASCII class names
    # unreadable for consumers that decode the file as UTF-8.
    with open(self.save_model_dir + '.label', 'w', encoding='utf-8') as file:
        file.writelines(name + '\n' for name in folder_names)
    return folder_names
def convert_tflite(self):
    """Convert the SavedModel at ``save_model_dir`` into a TFLite model file.

    Side effects: regenerates ``<save_model_dir>.label`` via
    get_folder_names(), writes the converted model to
    ``<save_model_dir>.tflite``, and prints a confirmation message.
    No converter optimizations are configured in this method.
    """
    # Refresh the label file so it matches the model being exported.
    self.get_folder_names()
    saved_model_converter = tf.lite.TFLiteConverter.from_saved_model(self.save_model_dir)
    tflite_bytes = saved_model_converter.convert()
    # Persist the serialized TFLite model next to the source model.
    target_path = self.save_model_dir + '.tflite'
    with open(target_path, 'wb') as out_file:
        out_file.write(tflite_bytes)
    print("转换成功,已保存为 tflite")
# Load the Keras model stored at ``<save_model_dir>.keras`` and convert it to
# TFLite with float16 post-training quantization, writing the result to
# ``<save_model_dir>.tflite``. The label file is regenerated first via
# get_folder_names().
def convert_model_tflite(self):
# Refresh the label file so it matches the model being exported.
self.get_folder_names()
model = keras.models.load_model(self.save_model_dir + ".keras")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Optimize.DEFAULT together with float16 as the supported type is the
# TFLite recipe for float16 post-training quantization (per the TF docs,
# weights are stored as float16, roughly halving model size).
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
# Save the converted TFLite model to a file.
with open(self.save_model_dir + '.tflite', 'wb') as f:
# NOTE(review): the source article is cut off mid-statement here (paywall);
# presumably this was `f.write(tflite_model)` -- confirm against the linked repo.
f.wri

(原文以下内容需付费解锁,此处被截断;完整代码可参考文首的 GitHub 项目地址。)






