General Workflow for Calling, from Java, Complex PyTorch Models Built on Pretrained Models such as BERT

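The reason for calling a PyTorch-trained model from Java is to take full advantage of Java's multithreading capabilities and so improve the model's inference throughput in production.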

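There are currently two main known ways for Java to call a model produced by PyTorch. The first is to convert the PyTorch model into a TorchScript model with torch.jit.trace or torch.jit.script. The second is to convert it into an ONNX (Open Neural Network Exchange) model. Java then loads the converted TorchScript or ONNX model through a library such as DJL.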
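The two TorchScript conversion routes behave differently for models with data-dependent control flow. The sketch below uses a hypothetical toy module `Gate`, which is not from the original article:

- `torch.jit.trace` runs the model once on an example input and records only the operations that were actually executed.
- `torch.jit.script` compiles the Python source, so the `if` statement survives.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical toy module whose output depends on an if-statement."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = Gate().eval()

# Tracing runs the model on this example. Its sum is positive, so only the
# "x * 2" branch is recorded (PyTorch emits a TracerWarning about this).
traced = torch.jit.trace(model, (torch.ones(3),))

# Scripting compiles the source and keeps both branches.
scripted = torch.jit.script(model)

neg = -torch.ones(3)
print(traced(neg))    # replays the recorded branch: tensor([-2., -2., -2.])
print(scripted(neg))  # takes the else branch: tensor([1., 1., 1.])
```

When a model's forward pass has no input-dependent branches, the two routes give the same results. When it does, only scripting preserves the branching.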

First, a brief introduction to TorchScript. The following is excerpted from the official PyTorch documentation. See also the tutorials "Introduction to TorchScript" and "Loading a TorchScript Model in C++" in the PyTorch Tutorials 1.11.0+cu102 documentation.

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.
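The "save from Python, load without Python" workflow described above can be sketched as follows. `TinyClassifier` is a hypothetical stand-in for a real model. The `.pt` file written by `save` is a TorchScript artifact of the kind that Python-free runtimes are meant to load: libtorch in C++, or DJL's PyTorch engine on the Java side. The file is reloaded here in Python only to check that the round trip preserves the outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a real model such as a BERT-based classifier."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.linear(x), dim=-1)


torch.manual_seed(0)
model = TinyClassifier().eval()
example = torch.rand(1, 4)

with torch.no_grad():
    traced = torch.jit.trace(model, (example,))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_classifier.pt")
    traced.save(path)              # serializes compiled code and weights into one file
    loaded = torch.jit.load(path)  # the TinyClassifier class is not needed to load it

    with torch.no_grad():
        match = torch.allclose(loaded(example), model(example))

print(match)  # True
```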

For a gentle introduction to TorchScript, see the Introduction to TorchScript tutorial.

TorchScript Language

TorchScript is a statically typed subset of Python, so many Python features apply directly to TorchScript. See the full TorchScript Language Reference for details.
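Because TorchScript is statically typed, every value needs a type known at compile time. In the small sketch below, the function names are made up for illustration. Parameters without annotations are assumed to be `torch.Tensor`, so non-Tensor parameters, such as a float scale factor, must be annotated explicitly.

```python
import torch


@torch.jit.script
def double(x):  # no annotation: TorchScript assumes x is a torch.Tensor
    return x * 2


@torch.jit.script
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Without ": float", factor would be typed as Tensor, and passing a
    # Python float would be rejected when the function is called.
    return x * factor


print(double.code)  # the compiled signature shows x typed as Tensor
print(scale(torch.ones(2), 0.5))
```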

Built-in Functions and Modules

TorchScript supports the use of most PyTorch functions and many Python built-ins. See TorchScript Builtins for a full reference of supported functions.
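As a quick illustration of Python built-ins inside a scripted function, `len`, `max`, and a `for` loop over a typed list all compile. The function `longest` is hypothetical and not taken from the documentation.

```python
from typing import List

import torch


@torch.jit.script
def longest(words: List[str]) -> int:
    # len() and max() are TorchScript built-ins; List[str] must be annotated.
    best = 0
    for w in words:
        best = max(best, len(w))
    return best


print(longest(["bert", "tokenizer", "onnx"]))  # 9
```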

PyTorch Functions and Modules

TorchScript supports a subset of the tensor and neural network functions that PyTorch provides. Most methods on Tensor as well as functions in the torch namespace, all functions in torch.nn.functional and most modules from torch.nn are supported in TorchScript.
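For example, the hypothetical feed-forward block below can be scripted directly. It combines `torch.nn` modules (`nn.Linear`, `nn.LayerNorm`) with a `torch.nn.functional` call (`F.gelu`). It is an illustrative sketch loosely in the style of a Transformer layer, not the actual BERT implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Hypothetical residual block: dense -> GELU -> add input -> LayerNorm."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(F.gelu(self.dense(x)) + x)


block = torch.jit.script(FeedForward(8).eval())
out = block(torch.rand(2, 5, 8))  # shape: (batch, sequence, hidden)
print(out.shape)  # torch.Size([2, 5, 8])
```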

See TorchScript Unsupported Pytorch Constructs for a list of unsupported PyTorch functions and modules.

Python Functions and Modules

Many of Python’s built-in functions are supported in TorchScript. The math module is also supported (see math Module for details), but no other Python modules (built-in or third party) are supported.
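Here is a sketch of the `math` support. Attention layers scale scores by the square root of the head size, and that step can be written with `math.sqrt` inside a scripted function. The function name and arguments are illustrative. Per the text above, calling into another module such as `numpy` inside the same function would not be supported.

```python
import math

import torch


@torch.jit.script
def scale_scores(scores: torch.Tensor, head_size: int) -> torch.Tensor:
    # math.sqrt is one of the math-module functions TorchScript supports.
    return scores / math.sqrt(head_size)


print(scale_scores(torch.tensor([8.0]), 16))  # tensor([2.])
```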

Python Language Reference Comparison

For a full listing of supported Python features, see Python Language Reference Coverage.

Debugging

Disable JIT for Debugging

PYTORCH_JIT

Setting the environment variable PYTORCH_JIT=0 will disable all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run using native Python. Since TorchScript (scripting and tracing) is disabled with this flag, you can use tools like pdb to debug the model code. For example:

import torch

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

Debugging this script with pdb works except for when we invoke the @torch.jit.script function. We can globally disable JIT, so that we can call the @torch.jit.script function as a normal Python function and not compile it. If the above script is called disable_jit_example.py, we can invoke it like so:

$ PYTORCH_JIT=0 python disable_jit_example.py

and we will be able to step into the @torch.jit.script function as a normal Python function. To disable the TorchScript compiler for a specific function, see @torch.jit.ignore.
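Here is a minimal sketch of `@torch.jit.ignore` on a module method. The module and method names are hypothetical.

- The decorated method is left as a regular Python call, so it can contain arbitrary Python such as `print` or `pdb`.
- `forward` is still compiled.
- According to the PyTorch documentation, a module that calls an ignored function cannot be exported. `@torch.jit.unused` is the documented alternative when the model must be saved, which is what the Java side needs.

```python
import torch
import torch.nn as nn


class Debuggable(nn.Module):
    @torch.jit.ignore
    def debug_print(self, x: torch.Tensor) -> None:
        # Runs as ordinary Python, so pdb.set_trace() could also go here.
        print("debug_print:", tuple(x.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.debug_print(x)  # compiled as a call back into Python
        return x + 1


scripted = torch.jit.script(Debuggable())
print(scripted(torch.zeros(2)))  # prints "debug_print: (2,)" then tensor([1., 1.])
```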

Inspecting Code

TorchScript provides a code pretty-printer for all ScriptModule instances. This pretty-printer gives an interpretation of the script method’s code as valid Python syntax. For example:

import torch

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

A ScriptModule with a single forward method has a code attribute, which you can use to inspect the ScriptModule’s code. If the ScriptModule has more than one method, you will need to access .code on the method itself and not the module. We can inspect the code of a method named foo on a ScriptModule by accessing .foo.code. The example above produces this output:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

This is TorchScript’s compilation of the code for the forward method. You can use this to ensure TorchScript (tracing or scripting) has captured your model code correctly.
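To illustrate the multi-method case described above, here is a hypothetical module whose `forward` calls a helper method. After scripting, `.code` on the module shows the compiled `forward`. The helper's compiled code is reached through the method itself.

```python
import torch
import torch.nn as nn


class TwoMethods(nn.Module):
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.helper(x) * 2


m = torch.jit.script(TwoMethods())
print(m.code)         # compiled forward
print(m.helper.code)  # compiled helper, accessed on the method itself
```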

Interpreting Graphs

TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ATen (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals. As an example:
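Such a graph can be printed through the `.graph` attribute of a scripted function. The minimal sketch below reuses the loop from the `foo` example under the hypothetical name `loop_example`. The printed IR should contain ATen operators such as `aten::zeros`, along with the control-flow primitives `prim::Loop` and `prim::If`.

```python
import torch


@torch.jit.script
def loop_example(n: int) -> torch.Tensor:
    rv = torch.zeros(3, 4)
    for i in range(n):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


graph_text = str(loop_example.graph)  # SSA IR of the compiled function
print(graph_text)
```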

