一个pytorch module导出onnx时提示错误:
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
我遇到的这个现象比较特殊:
首先是一个常规的python class类里面包含了一个torch nn.Module对象成员,单独导出这个nn.Module为onnx能够成功。
这个python class的成员函数enhance调用了nn.Module计算,然后使用一个torch nn.Module作为wrapper的forward包装该enhance函数导出,即提示上面的错误。
解决方法:
把这个常规的python class改成torch nn.Module的子类,该错误消失。