创建新环境比较好
conda create -n myjax python=3.11
进入环境
conda activate myjax
我的电脑只有cpu所以安装最新版本为
pip install jax[cpu]==0.4.35 jaxlib==0.4.35
等待安装
检查是否成功且查看版本
python -c "import jax; print(jax.__version__)"
应该没有报错并显示版本0.4.35
接下来尝试一个简单的运算
python -c "import jax.numpy as jnp; print(jnp.add(1, 1))"
应该算出来是2且没有报错
安装成功,可以用自己的python运算了