PyTorch模型生产部署与C++应用实践
1. JIT模块特性与模型处理
JIT编译后的模块在训练和推理方面与Python中的表现一致。在训练时,我们可以像使用传统模型一样使用它们;而在推理时,也需要像传统模型那样设置,例如使用 torch.no_grad() 上下文来确保其正常工作。
对于算法相对简单的模型,如CycleGAN、分类模型和基于U-Net的分割模型,我们可以像之前那样对模型进行跟踪(tracing)。而对于更复杂的模型,有一个实用的特性:我们可以在其他脚本化或跟踪的代码中使用脚本化或跟踪的函数,并且在构建、跟踪或脚本化模块时使用脚本化或跟踪的子模块。我们还可以通过调用 nn.Models 来跟踪函数,但需要将所有参数设置为不需要梯度,因为对于跟踪的模型,这些参数将作为常量。
2. 可跟踪性的脚本化处理
在更复杂的模型中,如Fast R-CNN系列的检测模型或自然语言处理中使用的循环网络,带有控制流(如 for 循环)的部分需要进行脚本化处理。以U-Net模型为例, UNetUpBlock 类中的 center_crop 函数在跟踪时会遇到问题。
class UNetUpBlock(nn.Module):
# ...
def center_crop(self, layer, target_size):
_, _, layer_height, layer_width = layer.
超级会员免费看
订阅专栏 解锁全文
6705

被折叠的 条评论
为什么被折叠?



