Check the CUDA version
Run nvidia-smi in a terminal and read the header line:
NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8
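To read the version without eyeballing it, the "CUDA Version" field can be pulled out of that header with a regular expression. This is a minimal sketch that assumes the header has the same layout as the line above; the function name cuda_version_from_smi is hypothetical. On a real machine you could pass it the output of subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout, but that call is not exercised here.

```python
import re


def cuda_version_from_smi(header: str) -> tuple[int, int]:
    """Extract (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi output.

    Hypothetical helper; assumes the header layout shown in these notes.
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", header)
    if match is None:
        raise ValueError("no 'CUDA Version' field found")
    return int(match.group(1)), int(match.group(2))


line = "NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8"
print(cuda_version_from_smi(line))  # (11, 8)
```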
Find the matching install command on the JAX installation page: https://jax.readthedocs.io/en/latest/installation.html
pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
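The CUDA Version reported by nvidia-smi (11.8 here) is the highest CUDA version the driver supports, so use the CUDA 11 command: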
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
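The choice above is just a lookup from the CUDA major version to the pip extra. The sketch below encodes only the two options listed in these notes (cuda11_pip and cuda12_pip); the function name jax_install_command is hypothetical, and newer JAX releases may use different extras, so treat the installation page as the authority.

```python
def jax_install_command(cuda_major: int) -> str:
    """Return the pip command from these notes for a given CUDA major version.

    Hypothetical helper; only the cuda11_pip / cuda12_pip extras listed above are handled.
    """
    extras = {11: "cuda11_pip", 12: "cuda12_pip"}
    if cuda_major not in extras:
        raise ValueError(f"no wheel listed in these notes for CUDA {cuda_major}")
    return (
        f'pip install --upgrade "jax[{extras[cuda_major]}]" '
        "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
    )


# CUDA 11.8 -> major version 11
print(jax_install_command(11))
```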
Wait for the installation to finish.
Check that the installation succeeded:
import jax
import jax.numpy as jnp
# List the devices JAX can use; a working GPU install shows CUDA devices
print(jax.devices())
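A slightly fuller check is to ask JAX which backend it selected and run a small computation on it. This is a sketch, not an official test procedure: jax.default_backend() reports "gpu" when the CUDA wheel loaded correctly and falls back to "cpu" otherwise, so the arithmetic below succeeds either way and only the printed backend tells you whether the GPU is in use.

```python
import jax
import jax.numpy as jnp

# Backend JAX picked: "gpu" if the CUDA wheel works, otherwise "cpu"
print(jax.default_backend())

# A tiny computation on the default device: 0*0 + 1*1 + 2*2 + 3*3
x = jnp.arange(4.0)
y = jnp.dot(x, x)
print(y)  # 14.0
```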