1.问题描述
使用JAX
时返回了nan
:
from jax import numpy as jnp
x = jnp.divide(0., 0.)
print("x value:", x)
显示如下:
这个时候其实没办法定位问题在哪里,因为你写的代码是成百上千行的!!!
2.解决方式:
from jax import numpy as jnp
from jax.config import config
config.update("jax_debug_nans", True)
x = jnp.divide(0., 0.)
print("x value:", x)
显示如下:
因此可以加上上面的语法以定位发生错误的位置,而不是返回一个nan
值!