C++ API载入tensorflow graph

本文介绍如何使用C++API加载预训练好的TensorFlow图,并将其独立使用或嵌入到其他应用程序中。文章包括安装Bazel、克隆TensorFlow仓库、创建和加载图等步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

通过C++ API载入tensorflow graph

在tensorflow repo中,和C++相关的tutorial远没有python的那么详尽。这篇文章主要介绍如何利用C++来载入一个预训练好的graph,以便于单独使用或者嵌入到其他app中。

Requirements

  • 安装bazel:tensorflow是使用bazel来进行编译的,所以如果要编译其他需要用到tensorflow的文件,我们就需要用到bazel。关于bazel,如果想要了解更多,可以参考我的另外两篇博客:Bazel入门:编译C++项目Bazel入门2:C++编译常见用例

  • Clone TensorFlow repo。

    git clone --recursive https://github.com/tensorflow/tensorflow

构建graph

我们首先创建一个tensorflow graph,然后保存成protobuf备用。

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    a = tf.Variable(5.0, name='a')
    b = tf.Variable(6.0, name='b')
    c = tf.multiply(a, b, name="c")

    sess.run(tf.global_variables_initializer())

    print a.eval() # 5.0
    print b.eval() # 6.0
    print c.eval() # 30.0

    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)

创建二进制文件

让我们在tensorflow/tensorflow目录下创建一个名叫loader的目录,即tensorflow/tensorflow/loader,用于载入之前我们创建好的graph。

loader/目录下我们再创建一个新的文件叫做loader.cc。在loader.cc里我们要做以下几件事情:

  1. 初始化一个tensorflow session
  2. 载入之前我们创建好的graph
  3. 将这个graph加入到session里面
  4. 设置好输入输出
  5. 运行graph,得到输出
  6. 读取输出中的值
  7. 关闭session,释放资源
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

using namespace tensorflow;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:

  // Our graph doesn't require any inputs, since it specifies default values,
  // but we'll change an input to demonstrate.
  Tensor a(DT_FLOAT, TensorShape());
  a.scalar<float>()() = 3.0;

  Tensor b(DT_FLOAT, TensorShape());
  b.scalar<float>()() = 2.0;

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "c" operation from the graph
  status = session->Run(inputs, {"c"}, {}, &outputs);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_c = outputs[0].scalar<float>();

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results
  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>
  std::cout << output_c() << "\n"; // 30

  // Free any resources used by the session
  session->Close();
  return 0;
}

然后我们需要为我们的项目创建一个BUILD文件,这会告诉bazel要编译什么东西。在BUILD文件里我们要定义一个cc_binary,表示输出一个二进制文件。

cc_binary(
    name = "loader",
    srcs = ["loader.cc"],
    deps = [
        "//tensorflow/core:tensorflow",
    ]
)

那么最终文件结构如下:

  • tensorflow/tensorflow/loader/
  • tensorflow/tensorflow/loader/loader.cc
  • tensorflow/tensorflow/loader/BUILD

编译和运行

  • 在tensorflow repo的根目录下,运行./configure
  • 在tensorflow/tensorflow/loader目录下,运行bazel build :loader
  • 在tensorflow repo的根目录下,cd到 bazel-bin/tensorflow/loader目录下
  • 将graph protobuf 拷贝到models/graph.pb
  • 运行./loader,得到输出!

Reference

  1. Loading a TensorFlow graph with the C++ API
  2. tensorflow#issue:Packaged TensorFlow C++ library for bazel-independent use
### TensorFlow 1.12 C++ API 文档及相关资源 TensorFlow 的官方文档提供了详细的说明来帮助开发者了解其 C++ API。对于特定版本(如 TensorFlow 1.12),可以访问对应的存档页面获取相关文档和示例。 #### 官方文档链接 TensorFlow 提供了一个专门用于存储旧版文档的区域,可以通过以下方式找到所需版本的文档: - 访问 [TensorFlow 版本控制](https://www.tensorflow.org/versions)[^4] 页面。 - 找到并点击对应版本号(即 `v1.12`)进入该版本的具体文档页面。 #### 使用方法与示例代码 以下是关于如何使用 TensorFlow 1.12 中 C++ API 的一些基本指导: 1. **加载模型** 加载已保存的 TensorFlow 图形文件通常通过 `tensorflow::Session` 实现。下面是一个简单的例子展示如何加载 `.pb` 文件中的图结构以及执行推理操作。 ```cpp #include "tensorflow/cc/client/client_session.h" #include "tensorflow/core/framework/tensor.h" using namespace tensorflow; // Load graph definition from file. Status load_graph(const string& graph_file_name, std::unique_ptr<Session>* session) { GraphDef graph_def; TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), graph_file_name, &graph_def)); session->reset(NewSession(SessionOptions())); TF_RETURN_IF_ERROR((*session)->Create(graph_def)); return Status::OK(); } int main() { auto session = std::unique_ptr<Session>(); const char* model_path = "/path/to/model.pb"; // Step to initialize the session with loaded graph. Status s = load_graph(model_path, &session); if (!s.ok()) { LOG(ERROR) << "Error loading computation graph: " << s.ToString(); return -1; } // Define input tensor(s). Tensor input_tensor(DT_FLOAT, TensorShape({})); // Set values into 'input_tensor'... // Run inference using ClientSession. ClientSession client(session.get()); std::vector<Tensor> outputs; client.Run({{"input_node", input_tensor}}, {"output_node"}, &outputs); return 0; } ``` 上述程序片段展示了如何利用 TensorFlowC++ 接口完成从图形定义读取至实际运行预测的过程[^5]。 #### 已知问题及解决方案建议 如果遇到兼容性或其他技术难题时,可尝试查阅社区支持平台上的讨论帖或者提交新的议题请求协助解决。例如 GitHub Issues 和 Stack Overflow 都是非常活跃的技术交流场所,在那里能够获得来自全球开发者的及时反馈和支持。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值