Adapting TensorFlowInferenceInterface.java for plain Java

This post presents a Java inference wrapper for TensorFlow, adapted from TensorFlowInferenceInterface.java. It simplifies model loading and inference, supports loading a model from several sources (a file path, an InputStream, or an existing Graph), and exposes feed/run/fetch methods for input and output.


What follows is an adaptation of the TensorFlowInferenceInterface.java class (originally part of the TensorFlow Android examples) for plain Java; my skills here are limited, so take it as a starting point.

package org.tensorflow.examples;

import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;

/**
 * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
 * for inference.
 *
 * <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
 * example usage.
 */
public class ori_tensorflowinterface {
  private static final String TAG = "TensorFlowInferenceInterface";
  private static final String ASSET_FILE_PREFIX = "dir_of_pb";

  /**
   * Load a TensorFlow model (a serialized GraphDef proto) from disk.
   *
   * @param model The filepath to the GraphDef proto representing the model.
   */
  public ori_tensorflowinterface(String model) throws IOException {

    this.modelName = model;
    this.g = new Graph();
    this.sess = new Session(g);
    this.runner = sess.runner();

    try (InputStream is = new FileInputStream(model)) {
      byte[] graphDef = new byte[is.available()];
      final int numBytesRead = is.read(graphDef);
      if (numBytesRead != graphDef.length) {
        throw new IOException(
            "Read " + numBytesRead + " bytes of model file, expected " + graphDef.length);
      }
      loadGraph(graphDef, g);
    }
  }

  /**
   * Load a TensorFlow model from provided InputStream.
   * Note: The InputStream will not be closed after loading model, users need to
   * close it themselves.
   *
   * @param is The InputStream to use to load the model.
   */
  public ori_tensorflowinterface(InputStream is) throws IOException {

    // modelName is redundant for model loading from input stream, here is for
    // avoiding error in initialization as modelName is marked final.
    this.modelName = "";
    this.g = new Graph();
    this.sess = new Session(g);
    this.runner = sess.runner();

      int baosInitSize = Math.max(is.available(), 16384);
      ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
      int numBytesRead;
      byte[] buf = new byte[16384];
      while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
        baos.write(buf, 0, numBytesRead);
      }
      byte[] graphDef = baos.toByteArray();
      loadGraph(graphDef, g);
  }

  /**
   * Construct a TensorFlowInferenceInterface with provided Graph
   *
   * @param g The Graph to use to construct this interface.
   */
  public ori_tensorflowinterface(Graph g) {

    // modelName is redundant here, here is for
    // avoiding error in initialization as modelName is marked final.
    this.modelName = "";
    this.g = g;
    this.sess = new Session(g);
    this.runner = sess.runner();
  }

  /**
   * Runs inference between the previously registered input nodes (via feed*) and the requested
   * output nodes. Output nodes can then be queried with the fetch* methods.
   *
   * @param outputNames A list of output nodes which should be filled by the inference pass.
   */
  public void run(String[] outputNames) {
    run(outputNames, false);
  }

  /**
   * Runs inference between the previously registered input nodes (via feed*) and the requested
   * output nodes. Output nodes can then be queried with the fetch* methods.
   *
   * @param outputNames A list of output nodes which should be filled by the inference pass.
   * @param enableStats Kept for signature compatibility with the Android original; statistics are
   *     not collected in this adaptation.
   */
  public void run(String[] outputNames, boolean enableStats) {
    // Release any Tensors from the previous run calls.
    closeFetches();

    // Add fetches.
    for (String o : outputNames) {
      fetchNames.add(o);
      TensorId tid = TensorId.parse(o);
      runner.fetch(tid.name, tid.outputIndex);
    }
    // Run the session.
    fetchTensors = runner.run();
    closeFeeds();
    runner = sess.runner();
  }


  /** Returns a reference to the Graph describing the computation run during inference. */
  public Graph graph() {
    return g;
  }

  /** Returns the {@link Operation} in the Graph with the provided name, if it exists. */
  public Operation graphOperation(String operationName) {
    final Operation operation = g.operation(operationName);
    if (operation == null) {
      throw new RuntimeException(
          "Node '" + operationName + "' does not exist in model '" + modelName + "'");
    }
    return operation;
  }

  /**
   * Cleans up the state associated with this Object.
   *
   * <p>The TensorFlowInferenceInterface object is no longer usable after this method returns.
   */
  public void close() {
    closeFeeds();
    closeFetches();
    sess.close();
    g.close();
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      close();
    } finally {
      super.finalize();
    }
  }

  // Methods for taking a native Tensor and filling it with values from Java arrays.

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, boolean[] src, long... dims) {
    byte[] b = new byte[src.length];

    for (int i = 0; i < src.length; i++) {
      b[i] = src[i] ? (byte) 1 : (byte) 0;
    }

    addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
  }

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, float[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
  }

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, int[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
  }

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, long[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
  }

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, double[] src, long... dims) {
    addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
  }

  /**
   * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
   * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
   * as many elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, byte[] src, long... dims) {
    addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
  }

  /**
   * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
   * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
   * a Java {@code String} (which is a sequence of characters).
   */
  public void feedString(String inputName, byte[] src) {
    addFeed(inputName, Tensors.create(src));
  }

  /**
   * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
   * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
   * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
   */
  public void feedString(String inputName, byte[][] src) {
    addFeed(inputName, Tensors.create(src));
  }

  // Methods for taking a native Tensor and filling it with src from Java native IO buffers.

  /**
   * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
   * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
   * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
   * elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, FloatBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  /**
   * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
   * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
   * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
   * elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, IntBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  /**
   * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
   * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
   * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
   * elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, LongBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  /**
   * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
   * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
   * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
   * elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, DoubleBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(dims, src));
  }

  /**
   * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
   * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
   * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
   * elements as that of the destination Tensor. If {@link src} has more elements than the
   * destination has capacity, the copy is truncated.
   */
  public void feed(String inputName, ByteBuffer src, long... dims) {
    addFeed(inputName, Tensor.create(UInt8.class, dims, src));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
   * dst} must have length greater than or equal to that of the source Tensor. This operation will
   * not affect dst's content past the source Tensor's size.
   */
  public void fetch(String outputName, float[] dst) {
    fetch(outputName, FloatBuffer.wrap(dst));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
   * dst} must have length greater than or equal to that of the source Tensor. This operation will
   * not affect dst's content past the source Tensor's size.
   */
  public void fetch(String outputName, int[] dst) {
    fetch(outputName, IntBuffer.wrap(dst));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
   * dst} must have length greater than or equal to that of the source Tensor. This operation will
   * not affect dst's content past the source Tensor's size.
   */
  public void fetch(String outputName, long[] dst) {
    fetch(outputName, LongBuffer.wrap(dst));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
   * dst} must have length greater than or equal to that of the source Tensor. This operation will
   * not affect dst's content past the source Tensor's size.
   */
  public void fetch(String outputName, double[] dst) {
    fetch(outputName, DoubleBuffer.wrap(dst));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
   * dst} must have length greater than or equal to that of the source Tensor. This operation will
   * not affect dst's content past the source Tensor's size.
   */
  public void fetch(String outputName, byte[] dst) {
    fetch(outputName, ByteBuffer.wrap(dst));
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
   * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
   * or equal to that of the source Tensor. This operation will not affect dst's content past the
   * source Tensor's size.
   */
  public void fetch(String outputName, FloatBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
   * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
   * or equal to that of the source Tensor. This operation will not affect dst's content past the
   * source Tensor's size.
   */
  public void fetch(String outputName, IntBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
   * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
   * or equal to that of the source Tensor. This operation will not affect dst's content past the
   * source Tensor's size.
   */
  public void fetch(String outputName, LongBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
   * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
   * or equal to that of the source Tensor. This operation will not affect dst's content past the
   * source Tensor's size.
   */
  public void fetch(String outputName, DoubleBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }

  /**
   * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
   * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
   * or equal to that of the source Tensor. This operation will not affect dst's content past the
   * source Tensor's size.
   */
  public void fetch(String outputName, ByteBuffer dst) {
    getTensor(outputName).writeTo(dst);
  }


  private void loadGraph(byte[] graphDef, Graph g) throws IOException {
    try {
      g.importGraphDef(graphDef);
    } catch (IllegalArgumentException e) {
      throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
    }
  }

  private void addFeed(String inputName, Tensor<?> t) {
    // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
    TensorId tid = TensorId.parse(inputName);
    runner.feed(tid.name, tid.outputIndex, t);
    feedNames.add(inputName);
    feedTensors.add(t);
  }

  private static class TensorId {
    String name;
    int outputIndex;

    // Parse output names into a TensorId.
    //
    // E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
    public static TensorId parse(String name) {
      TensorId tid = new TensorId();
      int colonIndex = name.lastIndexOf(':');
      if (colonIndex < 0) {
        tid.outputIndex = 0;
        tid.name = name;
        return tid;
      }
      try {
        tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
        tid.name = name.substring(0, colonIndex);
      } catch (NumberFormatException e) {
        tid.outputIndex = 0;
        tid.name = name;
      }
      return tid;
    }
  }

  private Tensor<?> getTensor(String outputName) {
    int i = 0;
    for (String n : fetchNames) {
      if (n.equals(outputName)) {
        return fetchTensors.get(i);
      }
      ++i;
    }
    throw new RuntimeException(
        "Node '" + outputName + "' was not provided to run(), so it cannot be read");
  }

  private void closeFeeds() {
    for (Tensor<?> t : feedTensors) {
      t.close();
    }
    feedTensors.clear();
    feedNames.clear();
  }

  private void closeFetches() {
    for (Tensor<?> t : fetchTensors) {
      t.close();
    }
    fetchTensors.clear();
    fetchNames.clear();
  }

  // Immutable state.
  private final String modelName;
  private final Graph g;
  private final Session sess;

  // State reset on every call to run.
  private Session.Runner runner;
  private List<String> feedNames = new ArrayList<String>();
  private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
  private List<String> fetchNames = new ArrayList<String>();
  private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();

}
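Every feed and fetch method above accepts names in the `node_name[:output_index]` form handled by the private TensorId class, so `"foo"` addresses output 0 of node `foo` and `"foo:1"` addresses output 1. The following standalone sketch reproduces that parsing rule outside the wrapper; the `TensorIdDemo` class name is invented for illustration and is not part of the original code.

```java
import java.util.Arrays;

public class TensorIdDemo {
  // Parse "name[:output_index]" the same way TensorId.parse does:
  // no colon, or a non-numeric suffix after the colon, means output index 0.
  static String[] parse(String name) {
    int colon = name.lastIndexOf(':');
    if (colon >= 0) {
      try {
        int index = Integer.parseInt(name.substring(colon + 1));
        return new String[] {name.substring(0, colon), Integer.toString(index)};
      } catch (NumberFormatException e) {
        // Suffix was not a number: treat the whole string as the node name.
      }
    }
    return new String[] {name, "0"};
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(parse("foo")));        // [foo, 0]
    System.out.println(Arrays.toString(parse("foo:1")));      // [foo, 1]
    System.out.println(Arrays.toString(parse("scope/op:2"))); // [scope/op, 2]
  }
}
```

Note that `lastIndexOf(':')` is used rather than `indexOf`, so scoped names containing colons elsewhere still split at the final colon.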
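The InputStream constructor drains the stream in 16 KiB chunks through a ByteArrayOutputStream, since `is.available()` is only an estimate and cannot be trusted for sizing the final array. A self-contained sketch of that read loop (the `ReadAllDemo` class name is invented for illustration):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

public class ReadAllDemo {
  // Mirrors the wrapper's InputStream constructor: read in 16 KiB chunks
  // until EOF so the full GraphDef ends up in a single byte[].
  static byte[] readAll(InputStream is) throws IOException {
    ByteArrayOutputStream baos = new ByteArrayOutputStream(16384);
    byte[] buf = new byte[16384];
    int numBytesRead;
    while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
      baos.write(buf, 0, numBytesRead);
    }
    return baos.toByteArray();
  }

  public static void main(String[] args) throws IOException {
    // Use a payload larger than one chunk to exercise the loop.
    byte[] data = new byte[40000];
    for (int i = 0; i < data.length; i++) data[i] = (byte) (i % 251);
    byte[] out = readAll(new ByteArrayInputStream(data));
    System.out.println(out.length == data.length && Arrays.equals(out, data)); // true
  }
}
```

As the comment in the original code notes, the caller keeps ownership of the stream and must close it afterwards; the loop itself never closes `is`.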