The reason for calling a PyTorch-trained model from Java is to take full advantage of Java's multithreading and so improve the model's throughput in production.
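As far as I know, there are two main ways for Java to call a model produced by PyTorch. The first is to convert the PyTorch model into a TorchScript model with torch.jit.trace or torch.jit.script. The second is to convert it into an ONNX (Open Neural Network Exchange) model. Java then loads the converted TorchScript or ONNX model through a library such as DJL.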
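To make the TorchScript route concrete, here is a minimal sketch (assuming a recent PyTorch) that exports a small model with both torch.jit.trace and torch.jit.script, saves each to a .pt archive, and reloads the archives with torch.jit.load to confirm they reproduce the eager outputs. An archive of this kind is what a Java-side loader such as DJL would be given; the DJL side is not shown. TinyNet, export_torchscript and the file names are hypothetical. The ONNX route would go through torch.onnx.export instead and is also not shown.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a model trained in PyTorch."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


def export_torchscript(model: nn.Module, example: torch.Tensor, out_dir: str) -> dict:
    """Save a traced and a scripted TorchScript archive; return their paths."""
    model.eval()  # switch off training-only behaviour before export
    traced = torch.jit.trace(model, example)  # records the ops executed on `example`
    scripted = torch.jit.script(model)        # compiles forward() from its source
    paths = {
        "traced": os.path.join(out_dir, "tinynet_traced.pt"),
        "scripted": os.path.join(out_dir, "tinynet_scripted.pt"),
    }
    traced.save(paths["traced"])
    scripted.save(paths["scripted"])
    return paths


torch.manual_seed(0)
model = TinyNet()
x = torch.rand(3, 4)
paths = export_torchscript(model, x, tempfile.mkdtemp())

# torch.jit.load rebuilds each module from the archive alone (code + weights),
# which is what a non-Python consumer relies on.
with torch.no_grad():
    expected = model(x)
    results = {name: torch.jit.load(p)(x) for name, p in paths.items()}

for name, y in results.items():
    print(name, torch.allclose(y, expected))
```

Tracing records only the operations executed on the example input, while scripting compiles the source, so a model with data-dependent branches is usually safer to script.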
First, a brief introduction to TorchScript. The following is excerpted from the official PyTorch documentation (PyTorch; Introduction to TorchScript — PyTorch Tutorials 1.11.0+cu102 documentation; Loading a TorchScript Model in C++ — PyTorch Tutorials 1.11.0+cu102 documentation):
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.
For a gentle introduction to TorchScript, see the Introduction to TorchScript tutorial.
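The incremental transition described above typically means mixing the two modes: script the parts with data-dependent control flow and trace the rest. The following sketch, loosely modeled on the decision-gate example in the Introduction to TorchScript tutorial, traces a wrapper whose submodule is already scripted; Gate and Cell are hypothetical names. Although tracing only ever sees the positive branch, the exported model still takes the negative branch when the inputs require it.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical submodule with data-dependent control flow."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x
        return -x


class Cell(nn.Module):
    """Hypothetical wrapper: straight-line code that tracing handles well."""

    def __init__(self) -> None:
        super().__init__()
        # Script the branchy part so both branches survive in the exported model.
        self.gate = torch.jit.script(Gate())
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, h: torch.Tensor):
        new_h = torch.tanh(self.linear(x) + h)
        return self.gate(new_h)


torch.manual_seed(0)
cell = Cell()
x = torch.rand(3, 4)
h_pos, h_neg = torch.full((3, 4), 10.0), torch.full((3, 4), -10.0)

# Trace once with inputs that take the "positive" branch...
traced = torch.jit.trace(cell, (x, h_pos))

# ...then feed inputs that take the other branch: the scripted gate still decides.
with torch.no_grad():
    same_pos = torch.allclose(traced(x, h_pos), cell(x, h_pos))
    same_neg = torch.allclose(traced(x, h_neg), cell(x, h_neg))
print(same_pos, same_neg)
```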
TorchScript Language
TorchScript is a statically typed subset of Python, so many Python features apply directly to TorchScript. See the full TorchScript Language Reference for details.
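Because TorchScript is statically typed, each variable must have a single type, and unannotated parameters are assumed to be torch.Tensor. The sketch below shows an annotated function that compiles and a function whose variable changes type between branches, which the compiler rejects. The function names are hypothetical, and the code catches a generic Exception rather than assuming a specific error class.

```python
from typing import List

import torch


@torch.jit.script
def total_length(words: List[str]) -> int:
    # Unannotated parameters would default to torch.Tensor, so the
    # container and element types are spelled out here.
    n = 0
    for w in words:
        n += len(w)
    return n


def mixed_branch(flag: bool):
    # `r` is a Tensor in one branch and an int in the other, which a
    # statically typed language cannot assign a single type to.
    if flag:
        r = torch.rand(1)
    else:
        r = 4
    return r


try:
    torch.jit.script(mixed_branch)
    rejected = False
except Exception as err:  # the exact exception class is not relied on here
    rejected = True
    print(type(err).__name__)

print(total_length(["torch", "script"]), rejected)
```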
Built-in Functions and Modules
TorchScript supports the use of most PyTorch functions and many Python built-ins. See TorchScript Builtins for a full reference of supported functions.
PyTorch Functions and Modules
TorchScript supports a subset of the tensor and neural network functions that PyTorch provides. Most methods on Tensor as well as functions in the torch namespace, all functions in torch.nn.functional and most modules from torch.nn are supported in TorchScript.
See TorchScript Unsupported PyTorch Constructs for a list of unsupported PyTorch functions and modules.
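As a small illustration of the supported subset, the hypothetical attention_weights function below combines Tensor methods, a torch-namespace function and torch.nn.functional.softmax, and is compiled with torch.jit.script as an ordinary function. The sketch then checks that the eager and scripted results agree.

```python
import torch
import torch.nn.functional as F


def attention_weights(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Tensor methods (.size, .transpose), a torch-namespace function
    # (torch.matmul) and a torch.nn.functional function (F.softmax).
    d = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (float(d) ** 0.5)
    return F.softmax(scores, dim=-1)


scripted_attention = torch.jit.script(attention_weights)

torch.manual_seed(0)
q, k = torch.rand(2, 5, 8), torch.rand(2, 7, 8)
eager_out = attention_weights(q, k)
script_out = scripted_attention(q, k)
print(script_out.shape, torch.allclose(eager_out, script_out))
```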
Python Functions and Modules
Many of Python’s built-in functions are supported in TorchScript. The math module is also supported (see math Module for details), but no other Python modules (built-in or third party) are supported.
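Built-ins and the math module work on plain Python scalars inside a scripted function. The hypothetical log_sum_exp below assumes a non-empty list, uses len, range, max, math.exp and math.log, and is compared against torch.logsumexp as a reference.

```python
import math
from typing import List

import torch


@torch.jit.script
def log_sum_exp(xs: List[float]) -> float:
    # Built-ins: len(), range(), max(); math module: math.exp, math.log.
    # Subtracting the maximum first keeps math.exp from overflowing.
    # Assumes xs is non-empty.
    m = xs[0]
    for i in range(1, len(xs)):
        m = max(m, xs[i])
    total = 0.0
    for v in xs:
        total += math.exp(v - m)
    return m + math.log(total)


values = [1.0, 2.0, 3.0]
reference = torch.logsumexp(torch.tensor(values, dtype=torch.float64), dim=0).item()
print(log_sum_exp(values), reference)
```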
Python Language Reference Comparison
For a full listing of supported Python features, see Python Language Reference Coverage.
Debugging
Disable JIT for Debugging
PYTORCH_JIT
Setting the environment variable PYTORCH_JIT=0 will disable all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run using native Python. Since TorchScript (scripting and tracing) is disabled with this flag, you can use tools like pdb to debug the model code. For example:
```python
import torch

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
```
Debugging this script with pdb works except for when we invoke the @torch.jit.script function. We can globally disable JIT, so that we can call the @torch.jit.script function as a normal Python function and not compile it. If the above script is called disable_jit_example.py, we can invoke it like so:

```shell
$ PYTORCH_JIT=0 python disable_jit_example.py
```
and we will be able to step into the @torch.jit.script function as a normal Python function. To disable the TorchScript compiler for a specific function, see @torch.jit.ignore.
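One way to confirm that the flag takes effect, sketched under the assumption that torch.jit.script hands back the original Python function when JIT is disabled: run the same scripted_fn in a fresh interpreter with PYTORCH_JIT set to 1 and to 0, and check whether the result is a torch.jit.ScriptFunction. The child script is written to a temporary file because scripting needs the function's source code, which a `python -c` string does not provide. is_compiled and the file name are hypothetical.

```python
import os
import subprocess
import sys
import tempfile

# Child program: the scripted_fn from the example above, without the pdb call.
CHILD = """
import torch

@torch.jit.script
def scripted_fn(x: torch.Tensor):
    for i in range(12):
        x = x + x
    return x

print(isinstance(scripted_fn, torch.jit.ScriptFunction))
"""


def is_compiled(pytorch_jit: str) -> bool:
    # PYTORCH_JIT is read when torch is imported, so it is set in the
    # environment of a fresh interpreter rather than changed mid-process.
    with tempfile.TemporaryDirectory() as tmp:
        script = os.path.join(tmp, "disable_jit_check.py")
        with open(script, "w") as f:
            f.write(CHILD)
        env = dict(os.environ, PYTORCH_JIT=pytorch_jit)
        proc = subprocess.run(
            [sys.executable, script],
            env=env, capture_output=True, text=True, check=True,
        )
    return proc.stdout.strip().splitlines()[-1] == "True"


print("PYTORCH_JIT=1 ->", is_compiled("1"))
print("PYTORCH_JIT=0 ->", is_compiled("0"))
```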
Inspecting Code
TorchScript provides a code pretty-printer for all ScriptModule instances. This pretty-printer gives an interpretation of the script method’s code as valid Python syntax. For example:
```python
import torch

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)
```
A ScriptModule with a single forward method will have an attribute code, which you can use to inspect the ScriptModule's code. If the ScriptModule has more than one method, you will need to access .code on the method itself and not the module. We can inspect the code of a method named foo on a ScriptModule by accessing .foo.code. The example above produces this output:
```python
def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0
```
This is TorchScript's compilation of the code for the forward method. You can use this to ensure TorchScript (tracing or scripting) has captured your model code correctly.
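To see the difference between module-level .code and method-level .code, the sketch below scripts a hypothetical module with forward() plus one extra method marked with @torch.jit.export, and reads the pretty-printed source of each. The exact printed text can vary between PyTorch versions, so only the method headers are worth relying on.

```python
import torch
import torch.nn as nn


class TwoMethods(nn.Module):
    """Hypothetical module with forward() plus one extra compiled method."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.twice(x) + 1

    @torch.jit.export  # explicitly expose this method on the ScriptModule
    def twice(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


m = torch.jit.script(TwoMethods())
forward_src = m.code      # module-level .code refers to forward()
twice_src = m.twice.code  # other methods: access .code on the method itself
print(forward_src)
print(twice_src)
```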
Interpreting Graphs
TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.
TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ATen (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals. As an example: