[ESP32]:TFLite Micro推理CIFAR10模型

本文介绍了如何在ESP32平台上使用TensorFlowLiteMicro进行CIFAR10数据集的2D卷积神经网络模型训练,随后将模型转换为TFLite格式,并提供了两种推理方案:PC端JPEG图片转换为数组和ESP32本地JPEG解码。

[ESP32]:TFLite Micro推理CIFAR10模型

模型训练

数据集处理

from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, load_model, Model
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D, GlobalAveragePooling2D
import pandas as pd
import matplotlib.pyplot as plt
import time, pickle
from keras.utils import to_categorical
from keras import layers
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

y_train = y_train.reshape(y_train.shape[0])
y_test = y_test.reshape(y_test.shape[0])

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'training samples')
print(x_test.shape[0], 'validation samples')

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

模型搭建

def build_model():
    model = tf.keras.Sequential(
<think>我们正在为ESP32构建TensorFlowLiteMicro库,并准备在MicroPython中使用。以下是详细步骤:1.首先,我们需要构建适用于ESP32TFLiteMicro库。由于ESP32是Xtensa架构,我们需要针对该架构进行编译。2.然后,我们将构建的库集成到MicroPython的编译过程中,使得MicroPython可以调用TFLiteMicro的C++API。注意:由于MicroPython主要是用C编写的,而TFLiteMicro是C++库,因此我们需要使用extern"C"来封装接口,以便在MicroPython中调用。步骤概览:1.构建TFLiteMicroforESP322.创建MicroPython的ESP32端口,并将TFLiteMicro库集成进去3.编写MicroPython的C扩展模块,提供TFLiteMicro的Python接口4.编译整个项目由于整个过程较为复杂,我们只给出关键步骤和代码示例。详细步骤:###1.构建TFLiteMicroforESP32我们使用ESP-IDF的构建系统来构建TFLiteMicro。首先,将TFLiteMicro作为ESP-IDF的一个组件。在ESP-IDF项目中的components目录下,创建一个名为`tflite-micro`的组件。组件目录结构:```components/tflite-micro/├──include│└──tensorflow├──CMakeLists.txt└──...(源文件和依赖)```我们可以从TensorFlow的仓库中获取TFLiteMicro的源代码,并编写CMakeLists.txt。由于TFLiteMicro需要FlatBuffers,我们还需要将FlatBuffers作为另一个组件。这里我们简化处理,假设已经将TFLiteMicro的源代码放在`components/tflite-micro`目录下。在`components/tflite-micro/CMakeLists.txt`中,添加源文件并设置包含目录。示例CMakeLists.txt:```cmakeset(tflite_micro_src"tensorflow/lite/micro/all_ops_resolver.cc""tensorflow/lite/micro/micro_interpreter.cc"...#其他需要的源文件)idf_component_register(SRCS${tflite_micro_src}INCLUDE_DIRS"include"REQUIRESflatbuffers)```###2.为MicroPython创建C扩展模块我们需要创建一个MicroPython的C模块,该模块将提供TFLiteMicro的接口。例如,我们可以创建一个名为`tflite`的模块。在MicroPython的esp32端口的模块目录中(例如`ports/esp32/modules`),创建一个新目录`tflite`,其中包含:-`tflite.c`:模块的实现-`tflite.h`:头文件-`micropython.mk`:构建规则在`tflite.c`中,我们需要使用extern"C"来包含TFLiteMicro的头文件,并编写MicroPython绑定函数。示例`tflite.c`:```c#include"py/runtime.h"#include"py/obj.h"//使用extern"C"以避免C++名称修饰#ifdef__cplusplusextern"C"{#endif#include"tensorflow/lite/micro/micro_interpreter.h"#include"tensorflow/lite/micro/all_ops_resolver.h"#ifdef__cplusplus}#endif//定义MicroPython对象来表示TFLite模型和解释器typedefstruct_mp_tflite_interpreter_obj_t{mp_obj_base_tbase;tflite::MicroInterpreter*interpreter;}mp_tflite_interpreter_obj_t;//创建解释器的函数STATICmp_obj_ttflite_create_interpreter(mp_obj_tmodel_data){//获取模型数据的缓冲区mp_buffer_info_tmodel_buf;mp_get_buffer_raise(model_data,&model_buf,MP_BUFFER_READ);//创建解释器statictflite::AllOpsResolverresolver;constexprintkTensorArenaSize=100*1024;//根据实际调整staticuint8_ttensor_arena[kTensorArenaSize];tflite::MicroInterpreter*interpreter=newtflite::MicroInterpreter(tflite::GetModel(model_buf.buf),resolver,tensor_arena,kTensorArenaSize);//分配张量TfLiteStatusstatus=interpreter->AllocateTensors();if(status!=kTfLiteOk){mp_raise_msg(&mp_type_RuntimeError,"Failedtoallocatetensors");}//创建解释器对象mp_tflite_interpreter_obj_t*obj=m_new_obj(mp_tflite_interpreter_obj_t);obj->base.type=&mp_tflite_interpreter_type;//需要定义类型obj->interpreter=interpreter;returnMP_OBJ_FROM_PTR(obj);}STATICMP_DEFINE_CONST_FUN_OBJ_1(tflite_create_interpreter_obj,tflite_create_interpreter);//模块函数表STATICconstmp_rom_map_elem_ttflite_module_globals_table[]={{MP_ROM_QSTR(MP_QSTR_create_interpreter),MP_ROM_PTR(&tflite_create_interpreter_obj)},};STATICMP_DEFINE_CONST_DICT(tflite_module_globals,tflite_module_globals_table);//模块定义constmp_obj_module_ttflite_module={.base={&mp_type_module},.globals=(mp_obj_dict_t*)&tflite_module_globals,};//注册模块MP_REGISTER_MODULE(MP_QSTR_tflite,tflite_module);```###3.将TFLiteMicro模块添加到MicroPython的ESP32构建在MicroPython的esp32端口的`main.c`中,我们需要初始化这个模块。在`mp_init()`函数中,添加:```cexternconstmp_obj_module_ttflite_module;mp_obj_dict_t*mp_module_tflite_globals=tflite_module.globals;mp_map_init(&mp_module_tflite_globals->map,...);```或者,更简单的方法是在`mpconfigport.h`中启用该模块:```cexternconststruct_mp_obj_module_ttflite_module;#defineMICROPY_PORT_BUILTIN_MODULES\{MP_ROM_QSTR(MP_QSTR_tflite),MP_ROM_PTR(&tflite_module)},```###4.编译整个项目使用ESP-IDF的构建系统编译整个项目:```bashidf.pybuild```注意:由于TFLiteMicro需要较大的内存,确保在ESP32上分配足够的存储空间(如使用PSRAM)。###5.在MicroPython中使用在MicroPython中,我们可以这样使用:```pythonimporttflite#加载模型数据(例如从文件中读取)withopen('model.tflite','rb')asf:model_data=f.read()interpreter=tflite.create_interpreter(model_data)```以上是一个高度简化的流程,实际中需要处理更多的细节,例如输入和输出张量的访问、错误处理等。由于整个过程非常复杂,这里只提供了关键步骤的概述。具体实现时,可能需要根据实际情况调整。注意:TensorFlowLiteMicro库本身较大,需要确保ESP32的Flash和RAM足够。相关问题:§§相关问题§§:1.如何减少TFLiteMicro库的大小以适应ESP32的Flash限制?2.在MicroPython中如何访问TFLiteMicro解释器的输入和输出张量?3.如何为ESP32启用TFLiteMicro的硬件加速(如ESP-NN)?4.如何将TFLite模型转换为适用于MicroPython的C数组?</think>以下是为 ESP32 构建 TensorFlow Lite Micro 库并在 MicroPython 中使用的完整解决方案: --- ### 整体流程 ```mermaid graph TD A[准备环境] --> B[构建 TFLite Micro 库] B --> C[创建 MicroPython 固件] C --> D[集成 TFLiteMicroPython] D --> E[编写 MicroPython 接口] E --> F[部署到 ESP32] ``` --- ### 1. 准备构建环境 ```bash # 安装依赖 sudo apt-get install git wget flex bison gperf python3 python3-pip cmake ninja-build ccache libffi-dev libssl-dev # 克隆 ESP-IDF (v5.1+) git clone -b v5.1 https://github.com/espressif/esp-idf.git cd esp-idf ./install.sh source export.sh # 克隆 MicroPython git clone https://github.com/micropython/micropython.git cd micropython git submodule update --init ``` --- ### 2. 构建 TFLite Micro 库 ```bash # 克隆 TFLite Micro git clone https://github.com/tensorflow/tflite-micro.git cd tflite-micro # 构建 ESP32 专用库 make -f tensorflow/lite/micro/tools/make/Makefile \ TARGET=esp \ TARGET_ARCH=xtensa \ XTENSA_CORE=esp32 \ microlite # 生成的库路径 ls tensorflow/lite/micro/tools/make/gen/esp_xtensa/ ``` --- ### 3. 集成到 MicroPython #### 创建自定义模块 ```c // 在 micropython/ports/esp32 下创建 modtflite.c #include "py/runtime.h" #include "tensorflow/lite/micro/micro_interpreter.h" // MicroPython 对象表示 TFLite 解释器 typedef struct _mp_tflite_obj_t { mp_obj_base_t base; tflite::MicroInterpreter* interpreter; } mp_tflite_obj_t; // 创建解释器的函数 STATIC mp_obj_t tflite_create_interpreter(mp_obj_t model_data) { // 获取模型数据缓冲区 mp_buffer_info_t bufinfo; mp_get_buffer_raise(model_data, &bufinfo, MP_BUFFER_READ); // 创建解释器 static tflite::AllOpsResolver resolver; constexpr int kTensorArenaSize = 50 * 1024; static uint8_t tensor_arena[kTensorArenaSize]; tflite::MicroInterpreter* interpreter = new tflite::MicroInterpreter( tflite::GetModel(bufinfo.buf), resolver, tensor_arena, kTensorArenaSize); // 分配张量 if (interpreter->AllocateTensors() != kTfLiteOk) { mp_raise_msg(&mp_type_RuntimeError, "Tensor allocation failed"); } // 创建 Python 对象 mp_tflite_obj_t *obj = m_new_obj(mp_tflite_obj_t); obj->base.type = &mp_tflite_type; obj->interpreter = interpreter; return MP_OBJ_FROM_PTR(obj); } STATIC MP_DEFINE_CONST_FUN_OBJ_1(tflite_create_interpreter_obj, tflite_create_interpreter); // 模块定义 STATIC const mp_rom_map_elem_t tflite_module_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_tflite) }, { MP_ROM_QSTR(MP_QSTR_create_interpreter), MP_ROM_PTR(&tflite_create_interpreter_obj) }, }; STATIC MP_DEFINE_CONST_DICT(tflite_module_globals, tflite_module_globals_table); const mp_obj_module_t mp_module_tflite = { .base = { &mp_type_module }, .globals = (mp_obj_dict_t *)&tflite_module_globals, }; ``` #### 修改构建配置 在 `micropython/ports/esp32/mpconfigport.h` 中添加: ```c // 启用 TFLite 模块 extern const struct _mp_obj_module_t mp_module_tflite; #define MICROPY_PORT_BUILTIN_MODULES \ { MP_ROM_QSTR(MP_QSTR_tflite), MP_ROM_PTR(&mp_module_tflite) }, ``` 在 `micropython/ports/esp32/main.c` 中添加头文件: ```c #include "modtflite.h" ``` --- ### 4. 构建 MicroPython 固件 ```bash cd micropython/ports/esp32 # 配置构建选项 idf.py set-target esp32 idf.py menuconfig # 在配置中启用: # Component config -> MicroPython -> Enable TFLite support # 添加 TFLite 库路径到 CMakeLists.txt list(APPEND EXTRA_COMPONENT_DIRS /path/to/tflite-micro/tensorflow/lite/micro/tools/make/gen/esp_xtensa/lib ) # 构建固件 idf.py build ``` --- ### 5. MicroPython 使用示例 ```python import tflite import ubinascii # 加载模型 (转换为字节数组) with open('model.tflite', 'rb') as f: model_data = f.read() # 创建解释器 interp = tflite.create_interpreter(model_data) # 获取输入张量 input_tensor = interp.get_input_tensor(0) # 填充输入数据 (示例) input_data = bytearray([0]*784) # MNIST 输入尺寸 input_tensor.write_data(input_data) # 执行推理 interp.invoke() # 获取输出 output_tensor = interp.get_output_tensor(0) results = output_tensor.read_data() print("Prediction:", results.index(max(results))) ``` --- ### 关键优化技巧 1. **内存管理**: ```python # 预分配内存池 from micropython import alloc_emergency_exception_buf alloc_emergency_exception_buf(100) ``` 2. **模型压缩**: ```bash # 使用 xxd 将模型转换为 C 数组 xxd -i model.tflite > model.h ``` 3. **性能监控**: ```python import time start = time.ticks_us() interp.invoke() delta = time.ticks_diff(time.ticks_us(), start) print(f"Inference time: {delta} us") ``` --- ### 部署到 ESP32 ```bash # 刷写固件 idf.py -p /dev/ttyUSB0 flash # 上传模型文件 ampy -p /dev/ttyUSB0 put model.tflite ``` ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值