TensorFlow: tflite_convert and Interpreter

This article shows two ways to convert a TensorFlow model (a ".pb" file) into a TFLite model (a ".tflite" file), and then how to load the converted model with the TFLite Interpreter, preprocess image data, and predict image classes in batches.

1. Convert ".pb" to ".tflite"

 

Method 1: Python Convert 

import tensorflow as tf

graph_def_file = "retrained_graph_mobilenet_v2_1.4_224.pb"
input_arrays = ["input"]
# input_shape =
output_arrays = ["final_result"]

# Build a converter from the frozen graph, convert, and write the .tflite file.
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
with open("retrained_graph_mobilenet_v2_1.4_224_python.tflite", "wb") as f:
    f.write(tflite_model)

 

Method 2: Command-line convert (tflite_convert)

IMAGE_SIZE=224
tflite_convert \
  --graph_def_file=tf_files/retrained_graph_mobilenet_v2_1.4_224.pb \
  --output_file=tf_files/retrained_graph_mobilenet_v2_1.4_224_batch3.lite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shape=7,${IMAGE_SIZE},${IMAGE_SIZE},3 \
  --input_array=input \
  --output_array=final_result \
  --inference_type=FLOAT \
  --input_data_type=FLOAT

# Or invoke the converter as a module with the same flags:
#   python3 -m tensorflow.contrib.lite.python.tflite_convert
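
Either route produces the same flat-buffer file. A quick sanity check is to open the converted file with the Interpreter and print the shapes that were baked in; a sketch using the TF 1.x contrib Interpreter, matching the rest of this post:

import tensorflow as tf

# Load the converted model and inspect the fixed input/output shapes,
# e.g. the input should report [7, 224, 224, 3] for the command above.
interpreter = tf.contrib.lite.Interpreter(
    model_path="tf_files/retrained_graph_mobilenet_v2_1.4_224_batch3.lite")
interpreter.allocate_tensors()
for d in interpreter.get_input_details():
    print("input:", d['name'], d['shape'])
for d in interpreter.get_output_details():
    print("output:", d['name'], d['shape'])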

 

2. Interpret ".tflite"

# Image data pipeline:

import tensorflow as tf

def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
  """Decode an image file, resize it, and normalize it to a (1, H, W, 3) float32 array."""
  input_name = "file_reader"
  output_name = "normalized"
  file_reader = tf.read_file(file_name, input_name)

  # Pick the decoder based on the file extension.
  if file_name.endswith(".png"):
    image_reader = tf.image.decode_png(file_reader, channels=3, name='png_reader')
  elif file_name.endswith(".gif"):
    image_reader = tf.squeeze(tf.image.decode_gif(file_reader, name='gif_reader'))
  elif file_name.endswith(".bmp"):
    image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
  else:
    image_reader = tf.image.decode_jpeg(file_reader, channels=3, name='jpeg_reader')

  # Cast to float, add a batch dimension, resize, and normalize: (pixel - mean) / std.
  float_caster = tf.cast(image_reader, tf.float32)
  dims_expander = tf.expand_dims(float_caster, 0)
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])

  with tf.Session() as sess:
    result = sess.run(normalized)
  return result
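
read_tensor_from_image_file adds new graph ops and opens a new tf.Session on every call, which gets slow when preprocessing many images. A lighter alternative sketch that does the same resize-and-normalize with Pillow and NumPy only (assumes Pillow is installed; no GIF handling):

import numpy as np
from PIL import Image

def read_tensor_from_image_file_pil(file_name, input_height=224, input_width=224,
                                    input_mean=128, input_std=128):
  # Decode, force 3 channels, and resize with bilinear interpolation.
  img = Image.open(file_name).convert("RGB").resize(
      (input_width, input_height), Image.BILINEAR)
  # Same normalization as above: (pixel - mean) / std.
  arr = (np.asarray(img, dtype=np.float32) - input_mean) / input_std
  return np.expand_dims(arr, 0)  # shape (1, height, width, 3)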



# Arguments matching the MobileNet input: 224x224 images,
# normalized as (pixel - 128) / 128, i.e. mapped from [0, 255] to [-1, 1].
input_height = 224
input_width = 224
input_mean = 128
input_std = 128





#---------------------------------
# main script
#---------------------------------

import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="tf_lite_for_image/tf_files/retrained_graph_mobilenet_v2_1.4_224_batch3.lite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


# The number of image files must match the batch size fixed into the
# .tflite input shape at conversion time (see the Note below).
file_name_list = ["a.jpg", "b.jpg"]

# Preprocess each image and stack the results into one (batch, 224, 224, 3) float32 array.
input_batch = []
for file_name in file_name_list:
    temp = read_tensor_from_image_file(file_name,
                                       input_height=input_height,
                                       input_width=input_width,
                                       input_mean=input_mean,
                                       input_std=input_std)
    input_batch.append(np.squeeze(temp))
input_batch = np.array(input_batch, dtype=np.float32)


# Feed the batch to the input tensor.
interpreter.set_tensor(input_details[0]['index'], input_batch)


# Run inference and read back the class scores, shape (batch, num_classes).
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
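
Each row of output_data holds the class scores for one image. A small decoding sketch, assuming the retraining step also produced a labels file (the name retrained_labels.txt below is an assumption, one label per line in output order):

# Hypothetical labels file written at retraining time; adjust the path as needed.
with open("tf_files/retrained_labels.txt") as f:
    labels = [line.strip() for line in f]

for file_name, scores in zip(file_name_list, output_data):
    top = int(np.argmax(scores))
    print(file_name, "->", labels[top], scores[top])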

Note:

The input shape is fixed when the .lite file is converted, for example a batch of 7 images:

       --input_shape=7,${IMAGE_SIZE},${IMAGE_SIZE},3 

Passing all 7 samples through the net in a single invoke() is fast, but "file_name_list" must then contain exactly 7 image paths, and the resulting output_data has shape (7, num_classes).
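
If the number of images on hand does not match the batch size fixed at conversion time, one option is to re-convert; another is to resize the input tensor at runtime. A sketch, assuming a TFLite Interpreter version that supports resize_tensor_input and reusing the objects from the main script above:

# Change the batch dimension to match the actual number of images,
# then re-allocate and re-fetch the tensor details before set_tensor()/invoke().
interpreter.resize_tensor_input(input_details[0]['index'],
                                [len(file_name_list), input_height, input_width, 3])
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()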
