JAX快速入门指南:30分钟掌握NumPy加速、JIT编译与自动向量化
JAX是一个面向数组的数值计算库,它能够对Python+NumPy程序进行可组合的转换,如微分、向量化、JIT编译到GPU/TPU等,从而实现高效的机器学习研究。本指南将带你在30分钟内快速掌握JAX的核心功能,包括NumPy兼容接口、JIT编译加速、自动微分和自动向量化等,让你的数值计算程序跑得更快、更高效。
安装JAX
JAX可以在Linux、Windows和macOS上直接通过Python Package Index安装CPU版本:
pip install jax
如果你使用NVIDIA GPU,可以安装GPU版本:
pip install -U "jax[cuda12]"
更多详细的平台特定安装信息,请查看安装文档。
JAX与NumPy
JAX的NumPy接口
大多数JAX的使用都是通过熟悉的jax.numpy API,通常将其导入为jnp别名:
import jax.numpy as jnp
通过这个导入,你可以立即以类似于典型NumPy程序的方式使用JAX,包括使用NumPy风格的数组创建函数、Python函数和运算符以及数组属性和方法。
JAX数组与NumPy数组的区别
虽然JAX的接口与NumPy相似,但JAX数组和NumPy数组之间存在一些差异,这些差异在JAX - The Sharp Bits中有详细探讨。
JIT编译加速
为什么需要JIT编译
JAX可以透明地在GPU或TPU上运行(如果没有,则回退到CPU)。然而,在简单使用JAX时,它会一次向芯片发送一个操作的内核。如果有一系列操作,我们可以使用jax.jit函数通过XLA将这一系列操作一起编译,从而提高执行速度。
使用JIT编译函数
我们可以使用IPython的%timeit来快速基准测试函数,使用block_until_ready()来考虑JAX的动态调度(参见异步调度)。以下是一个示例:
from jax import random, jit
key = random.key(1701)
x = random.normal(key, (1_000_000,))
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# 未编译的函数
%timeit selu(x).block_until_ready()
# 使用JIT编译函数
selu_jit = jit(selu)
_ = selu_jit(x) # 第一次调用时编译
%timeit selu_jit(x).block_until_ready()
上述计时代表在CPU上的执行情况,但相同的代码可以在GPU或TPU上运行,通常会获得更大的加速。更多关于JAX中的JIT编译,请查看jit编译文档。
自动微分
基本的自动微分
除了通过JIT编译转换函数外,JAX还提供了其他转换。其中一个转换是jax.grad,它执行自动微分:
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
验证导数
我们可以使用有限差分来验证结果是否正确:
def first_finite_differences(f, x, eps=1E-3):
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
高阶导数和雅可比矩阵
jax.grad和jax.jit转换可以任意组合和混合。除了标量值函数外,jax.jacobian转换可用于计算向量值函数的完整雅可比矩阵:
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
对于更高级的自动微分操作,可以使用jax.vjp进行反向模式向量-雅可比乘积,以及jax.jvp和jax.linearize进行前向模式雅可比-向量乘积。这两者可以相互任意组合,也可以与其他JAX转换组合。例如,可以组合它们来创建一个高效计算完整海森矩阵的函数:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
这种组合在实践中会产生高效的代码,这大致就是JAX内置的jax.hessian函数的实现方式。更多关于JAX中的自动微分,请查看自动微分文档。
自动向量化
什么是自动向量化
另一个有用的转换是jax.vmap,即向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是显式循环函数调用,而是将函数转换为本机向量化版本以获得更好的性能。当与jax.jit组合时,它可以与手动重写函数以处理额外的批处理维度一样高效。
使用vmap进行自动向量化
我们将通过一个简单的例子,使用jax.vmap将矩阵-向量乘积提升为矩阵-矩阵乘积。虽然在这个特定情况下手动执行很容易,但相同的技术可以应用于更复杂的函数:
from jax import random, vmap, jit
key1, key2 = random.split(random.key(1701))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
return jnp.dot(mat, x)
# 手动批处理
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, mat.T)
# 自动向量化
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
正如你所期望的,jax.vmap可以与jax.jit、jax.grad以及任何其他JAX转换任意组合。更多关于JAX中的自动向量化,请查看自动向量化文档。
JAX的功能远不止于此,它为数值计算和机器学习研究带来了巨大的潜力。希望本快速入门指南能帮助你快速上手JAX,并在实际项目中发挥其强大的作用。如果你想深入了解更多JAX的功能和用法,可以参考官方文档和示例代码。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考




