具有TFLite的Selfie2Anime —第2部分:TFLite模型

这是一个关于如何将TensorFlow 1.x的U-GAT-IT模型转换为TFLite,并在Android上部署以进行自拍到动漫转换的教程。内容包括使用TF1保存模型为SavedModel,从Kaggle下载模型检查点,使用TF2转换为TFLite,运行推理,添加元数据,以及Android上的模型性能基准测试。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Written by ML GDEs Margaret Maynard-Reid and Sayak Paul | Reviewed by Khanh LeViet and Hoi Lam

ML GDEs 玛格丽特·梅纳德·里德萨耶克·保罗撰写 Khanh LeVietHoi Lam评论

This is part 2 of an end-to-end tutorial on how to convert a TF 1.x model to TensorFlow Lite (TFLite), and then deploy it to an Android for transforming an selfie image to a plausible anime. (Part 1 | Part 2 |Part 3) The tutorial is the first of a series of E2E TFLite tutorials of awesome-tflite.

这是有关如何将TF 1.x模型转换为TensorFlow Lite(TFLite),然后将其部署到Android以便将自拍图像转换为合理的动画的端到端教程的第2部分。 ( 第1部分 | 第2部分 | 第3部分 )该教程是awesome-tflite的E2E TFLite系列教程的第一个。

Here is a step-by-step summary:

以下是分步摘要:

  • Generate a SavedModel out of the pre-trained U-GAT-IT model checkpoints.

    从预先训练的U-GAT-IT模型检查点中生成SavedModel

  • Convert SavedModel using the latest TFLiteConverter

    使用最新的TFLiteConverter转换TFLiteConverter

  • Run inference in Python with the converted model

    使用转换后的模型在Python中运行推理
  • Add metadata to enable easy integration with mobile app

    添加元数据以轻松与移动应用程序集成
  • Run model benchmark to make sure the model runs well on mobile

    运行模型基准测试,以确保模型在移动设备上运行良好

使用TF1保存模型-从预先训练的检查点创建SavedModel (Model saving with TF1 — create a SavedModel from pre-trained checkpoints)

Please note that this part needs to run in a TensorFlow 1.x runtime. We used TensorFlow 1.14 because that was the version the model code was written with.

请注意,这部分需要在TensorFlow 1.x运行时中运行。 我们使用TensorFlow 1.14,因为那是编写模型代码的版本。

The U-GAT-IT authors provided the two checkpoints: one extracted after 50 epochs (~4.6GB) and the other extracted after 100 epochs (4.7GB). We will be using a much lighter version from Kaggle, that is suitable for mobile-based deployments.

U-GAT-IT的作者提供了两个检查点:一个检查点是在50个纪元 (约4.6GB)后提取的,另一个是在100个纪元 (4.7GB)后提取的。 我们将使用Kaggle提供的更轻便的版本,该版本适用于基于移动的部署。

从Kaggle下载并提取模型检查点 (Download and extract the model checkpoints from Kaggle)

So, first things first! Let’s download the checkpoints from Kaggle with the Kaggle API. On kaggle.com, go to My Account/API, click on “Create new API token” which triggers the download of kaggle. json, containing your API credentials. Then in Colab, you can specify the following and set the environment variables -

所以,第一件事! 让我们使用Kaggle API从Kaggle下载检查点。 在kaggle.com上,转到“我的帐户/ API”,单击“创建新的API令牌”,这会触发kaggle的下载。 json,其中包含您的API凭据。 然后,在Colab中,您可以指定以下内容并设置环境变量-

Let’s download the checkpoints and extract them-

让我们下载检查点并提取它们-

$ kaggle datasets download -d t04glovern/ugatit-selfie2anime-pretrained
$ unzip -qq /content/ugatit-selfie2anime-pretrained.zip

Load model checkpoints and connect the tensors

加载模型检查点并连接张量

This step usually varies from model to model. A general workflow that is followed in this step is as follows:

此步骤通常因模型而异。 此步骤遵循的一般工作流程如下:

  1. Defining the input and output tensors of the model.

    定义模型的输入和输出张量。
  2. Instantiating the model and connecting the input and the output tensors so that a computation graph can be built.

    实例化模型并连接输入和输出张量,以便可以构建计算图。
  3. Loading the pre-trained checkpoints in the model’s graph.

    将预训练的检查点加载到模型的图形中。
  4. Generate the SavedModel.

    生成SavedModel。

It is worth noting that step 2 in this workflow can vary from model to model so it’s really hard to know that beforehand. For this section, we are going to only focus on the part of the code that is important to understand, for the full implementation, please check out the Colab Notebook that accompanies this tutorial.

值得注意的是,此工作流程中的步骤2可能因模型而异,因此事先很难知道这一点。 在本节中,我们将只专注于重要的代码部分,对于完整的实现,请查看本教程随附的Colab Notebook

In our case, the input and the output tensors and their details can be accessed from an instance of the main model class. So, we will start by instantiating an instance of the UGATIT model class -

在我们的例子中,可以从主模型类的实例访问输入和输出张量及其详细信息。 因此,我们将从实例化UGATIT模型类的实例开始-

with tf.Graph().as_default(), tf.Session() as sess:
   gan = UGATIT(sess, data)
   gan.build_model()
   load_checkpoint(sess, ckpt_path)

data refers to the model configurations as can be seen here. The UGATIT class comes from here. At this point, our model should have been instantiated. Now we need to load the checkpoints into the model via the session into which it is loaded which is what load_checkpoint() method does -

data指的是模型配置可以看出这里UGATIT类来自此处 。 至此,我们的模型应该已经实例化了。 现在,我们需要通过将检查点加载到模型中的会话将其加载到模型中,这就是load_checkpoint()方法的作用-

def load_checkpoint(sess, ckpt_path):
   model_saver = tf.train.Saver(tf.global_variables())
   checkpoint = os.path.expanduser(checkpoint)
   if tf.gfile.IsDirectory(checkpoint):
       checkpoint = tf.train.latest_checkpoint(checkpoint)
       tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))
   model_saver.restore(sess, checkpoint)

At this point, creating the SavedModel needs only a matter of a few keystrokes. Remember that we are still under the Session context.

在这一点上,创建SavedModel只需几次按键即可。 请记住,我们仍处于Session上下文之下。

tf.saved_model.simple_save(
       sess,
       saved_model_dir,
       inputs={gan.test_domain_A.name: gan.test_domain_A},
       outputs={gan.test_fake_B.name: gan.test_fake_B}
)

As we can see in the above code, the input and the output tensors can be accessed from the model graph itself. After this code is executed, we should have the SavedModel files ready. We can proceed with converting this SavedModel to a TFLite model.

正如我们在上面的代码中看到的那样,可以从模型图本身访问输入和输出张量。 执行此代码后,我们应该准备好SavedModel文件。 我们可以继续将此SavedModel转换为TFLite模型。

准备TFLite模型 (Prepare the TFLite model)

Time to shift gears to TensorFlow 2.x (2.2.0 or any higher nightly versions). In this section, we will be using the SavedModel we generated previously and convert it to a TFLite flat buffer, which is about 10 MB in size and perfectly usable in a Mobile Application. Then we will use a few of the latest TensorFlow Lite tools to prepare the model for deployment:

是时候将齿轮换到TensorFlow 2.x(2.2.0或任何更高版本的夜间版本)了。 在本节中,我们将使用之前生成的SavedModel并将其转换为TFLite平面缓冲区,该缓冲区大小约为10 MB ,可以在移动应用程序中完美使用。 然后,我们将使用一些最新的TensorFlow Lite工具来准备要部署的模型:

  • Run inference in Python with the TFLite model to make sure it’s good after the conversion.

    使用TFLite模型在Python中运行推理,以确保转换后效果良好。
  • Add metadata to the TFLite model to make integrating it to an Android app easier with the Android Studio’s ML Model Binding plugin.

    将元数据添加到TFLite模型中,以便通过Android Studio的ML模型绑定插件将其轻松集成到Android应用中。
  • Use the Benchmark tool to see how the model would perform on mobile devices.

    使用基准工具查看该模型在移动设备上的性能。

使用TF2将SavedModel转换为TFLite (Convert SavedModel to TFLite with TF2)

First, we load the SavedModel files and create a concrete function from them -

首先,我们加载SavedModel文件并从中创建一个具体函数 -

model = tf.saved_model.load(saved_model_path)
concrete_func = model.signatures[ tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

The advantage of doing the conversion in this way is it gives us the flexibility to set the shapes of the input and output tensors of the resulting TFLite model. You can see this in the following code snippet -

以这种方式进行转换的优点是,它使我们可以灵活地设置所得TFLite模型的输入和输出张量的形状。 您可以在以下代码段中看到这一点-

concrete_func.inputs[0].set_shape([1, 256, 256, 3])
concrete_func.outputs[0].set_shape([1, 256, 256, 3])

It is recommended to use the original shapes of the input and output tensors that were used during training the model accordingly. In this case, this shape is (1, 256, 256, 3) and 1 denotes the batch dimension. This is required because the model expects the data to be in the shape of: BATCH_SIZE, IMAGE_SHAPE, IMAGE_SHAPE, NB_CHANNELS. To do the actual conversion we run the following -

建议使用在训练模型时相应使用的输入和输出张量的原始形状。 在这种情况下,此形状为(1、256、256、3),并且1表示批次尺寸。 这是必需的,因为模型期望数据的形状为:BATCH_SIZE,IMAGE_SHAPE,IMAGE_SHAPE,NB_CHANNELS。 为了进行实际的转换,我们运行以下命令-

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()

Unless we specify any optimization option explicitly to the converter, the model would still be a float model. You can explore the different optimization options available in TFLite from here.

除非我们为converter明确指定任何优化选项,否则该模型仍将是浮点模型。 您可以从此处探索TFLite中可用的不同优化选项。

使用TFLite模型运行推理 (Run Inference with TFLite model)

After the conversion and before deploying the .tflite model, it’s always a good practice to run inference in Python to confirm that it’s working as intended.

在转换之后和部署.tflite模型之前,在Python中运行推理以确认其按预期工作一直是一个好习惯。

We have tried the model on a few faces and it turns out that it produces much better results on female faces than male ones. A closer look at the training dataset reveals that all faces are female faces, and the model bias because the model was trained on only female faces.

我们已经在几个面Kong上尝试了该模型,结果表明,与男性面Kong相比,该模型在女性面Kong上产生的效果要好得多。 仔细查看训练数据集可以发现所有面Kong都是女性面Kong,并且由于模型仅在女性面Kong上进行训练,因此模型存在偏差。

Here is a screen of the test result:

这是测试结果的屏幕:

Image for post

将元数据添加到TFLite模型 (Add Metadata to TFLite Model)

Let’s add metadata to the TensorFlow Model so that we can auto generate model inference code on Android.

让我们将元数据添加到TensorFlow模型中,以便我们可以在Android上自动生成模型推断代码。

Option A: via command line

选项A:通过命令行

If you are adding metadata with the Python script via command line, make sure to first pip install tflite-support in your conda or virtualenv environment. And set up the folder structure as follows:

如果要通过命令行使用Python脚本添加元数据,请确保首先在pip install tflite-support或virtualenv环境中pip install tflite-support 。 并设置文件夹结构如下:

metadata_writer_for_selfie2anime.py
|-- model_without_metadata
|   |--selfie2anime.tflite
|-- model_with_metadata

Then use the metadata_writer_for_selfie2anime.py script to add metadata to the selfie2anime.tflite model:

然后,使用metadata_writer_for_selfie2anime.py脚本将元数据添加到selfie2anime.tflite模型中:

python ./metadata_writer_for_selfie2anime.py \
--model_file=./model_without_metadata/selfie2anime.tflite \
--export_directory=model_with_metadata

Optiona B: via Colab

选项B:通过Colab

Alternatively, you could use this Colab notebook instead. Remember to also first $pip install tflite-support. This option maybe easier for you if you are not familiar with running Python scripts in command line. All you need is to launch the notebook in a browser, upload the selfie2anime.tflite file and execute all cells.

或者,您可以改用此Colab笔记本 。 记住也要先$pip install tflite-support 。 如果您不熟悉在命令行中运行Python脚本,则此选项可能会更方便。 您所需要做的就是在浏览器中启动笔记本电脑,上传selfie2anime.tflite文件并执行所有单元。

Metadata added

添加了元数据

Two new file selfie2anime.tflite and selfie2anime.json are created under the model_with_metadatafolder. This new selfie2anime.tflite contains the model metadata which we can use as input to the ML model Binding in Android Studio when deploying the model to Android. And the selfie2anime.json is for you to verify if the metadata added to the model is correct.

model_with_metadata文件夹下创建了两个新文件selfie2anime.tfliteselfie2anime.json 。 这个新的selfie2anime.tflite包含模型元数据,当将模型部署到Android时,我们可以将其用作Android Studio中ML模型绑定的输入。 selfie2anime.json用于您验证添加到模型中的元数据是否正确。

metadata_writer_for_selfie2anime.py
|-- model_without_metadata
|   |--selfie2anime.tflite
|-- model_with_metadata
|   |--selfie2anime.tflite
|   |--selfie2anime.json

To learn more about how the TFLite metadata works, refer to the documentation here.

要了解有关TFLite元数据如何工作的更多信息,请参阅此处的文档。

Android上的基准模型性能(可选) (Benchmark model perf on Android (Optional))

As an optional step, we used the TFLite Android Model Benchmark tool to get the runtime performance on Android before deploying it. Please refer to the instructions on the benchmark tool for details.

作为可选步骤,我们在部署前使用TFLite Android Model Benchmark工具获取了Android上的运行时性能。 有关详细信息,请参考基准工具上的说明。

Here are the high-level summary steps:

以下是高级摘要步骤:

  • Configure Android NDK/SDK — there are some Android SDK/NDK prerequisites then you build the tool with bazel.

    配置Android NDK / SDK-有一些Android SDK / NDK先决条件,然后使用bazel构建该工具。
  • Build the benchmark apk

    建立基准APK
bazel build -c opt \
    --config=android_arm64 \         
    //tensorflow/lite/tools/benchmark:benchmark_model
  • Use adb (Android Debug Bridge) to install the benchmarking tool and push the selfie2anime.tflite model to Android device:

    使用adb(Android Debug Bridge)安装基准测试工具,并将selfie2anime.tflite模型推入Android设备:
adb install -r -d -g bazel-bin/tensorflow/lite/tools/benchmark/android/benchmark_model.apk
adb push selfie2anime.tflite /data/local/tmp
  • Run the benchmark tool

    运行基准测试工具
adb shell /data/local/tmp/benchmark_model 
  --graph=/data/local/tmp/selfie2anime.tflite 
  --num_threads=4

We see the benchmark result as follows — and it’s a bit slow: Inference timings in us: Init: 7135, First inference: 7428506, Warmup (avg): 7.42851e+06, Inference (avg): 7.26313e+06

我们看到的基准测试结果如下所示-有点慢: 我们的推理时间:初始化:7135,初次推理:7428506,预热(avg):7.42851e + 06,推理(avg):7.26313e + 06

Image for post

Now that you have a TensorFlow Lite model, let’s see how we would implement the model on Android (Part 3).

现在您已经有了一个TensorFlow Lite模型,让我们看看如何在Android上实现该模型( 第3部分 )。

翻译自: https://medium.com/google-developer-experts/selfie2anime-with-tflite-part-2-tflite-model-84002cf521dc

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值