JAX快速入门指南:30分钟掌握NumPy加速、JIT编译与自动向量化

JAX快速入门指南:30分钟掌握NumPy加速、JIT编译与自动向量化

【免费下载链接】jax Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more 【免费下载链接】jax 项目地址: https://gitcode.com/gh_mirrors/jax/jax

JAX是一个面向数组的数值计算库,它能够对Python+NumPy程序进行可组合的转换,如微分、向量化、JIT编译到GPU/TPU等,从而实现高效的机器学习研究。本指南将带你在30分钟内快速掌握JAX的核心功能,包括NumPy兼容接口、JIT编译加速、自动微分和自动向量化等,让你的数值计算程序跑得更快、更高效。

安装JAX

JAX可以在Linux、Windows和macOS上直接通过Python Package Index安装CPU版本:

pip install jax

如果你使用NVIDIA GPU,可以安装GPU版本:

pip install -U "jax[cuda12]"

更多详细的平台特定安装信息,请查看安装文档

JAX与NumPy

JAX的NumPy接口

大多数JAX的使用都是通过熟悉的jax.numpy API,通常将其导入为jnp别名:

import jax.numpy as jnp

通过这个导入,你可以立即以类似于典型NumPy程序的方式使用JAX,包括使用NumPy风格的数组创建函数、Python函数和运算符以及数组属性和方法。

JAX数组与NumPy数组的区别

虽然JAX的接口与NumPy相似,但JAX数组和NumPy数组之间存在一些差异,这些差异在JAX - The Sharp Bits中有详细探讨。

JIT编译加速

为什么需要JIT编译

JAX可以透明地在GPU或TPU上运行(如果没有,则回退到CPU)。然而,在简单使用JAX时,它会一次向芯片发送一个操作的内核。如果有一系列操作,我们可以使用jax.jit函数通过XLA将这一系列操作一起编译,从而提高执行速度。

使用JIT编译函数

我们可以使用IPython的%timeit来快速基准测试函数,使用block_until_ready()来考虑JAX的动态调度(参见异步调度)。以下是一个示例:

from jax import random, jit

key = random.key(1701)
x = random.normal(key, (1_000_000,))

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# 未编译的函数
%timeit selu(x).block_until_ready()

# 使用JIT编译函数
selu_jit = jit(selu)
_ = selu_jit(x)  # 第一次调用时编译
%timeit selu_jit(x).block_until_ready()

上述计时代表在CPU上的执行情况,但相同的代码可以在GPU或TPU上运行,通常会获得更大的加速。更多关于JAX中的JIT编译,请查看jit编译文档

自动微分

基本的自动微分

除了通过JIT编译转换函数外,JAX还提供了其他转换。其中一个转换是jax.grad,它执行自动微分:

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

验证导数

我们可以使用有限差分来验证结果是否正确:

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

高阶导数和雅可比矩阵

jax.grad和jax.jit转换可以任意组合和混合。除了标量值函数外,jax.jacobian转换可用于计算向量值函数的完整雅可比矩阵:

from jax import jacobian
print(jacobian(jnp.exp)(x_small))

对于更高级的自动微分操作,可以使用jax.vjp进行反向模式向量-雅可比乘积,以及jax.jvp和jax.linearize进行前向模式雅可比-向量乘积。这两者可以相互任意组合,也可以与其他JAX转换组合。例如,可以组合它们来创建一个高效计算完整海森矩阵的函数:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))

这种组合在实践中会产生高效的代码,这大致就是JAX内置的jax.hessian函数的实现方式。更多关于JAX中的自动微分,请查看自动微分文档

自动向量化

什么是自动向量化

另一个有用的转换是jax.vmap,即向量化映射。它具有沿数组轴映射函数的熟悉语义,但不是显式循环函数调用,而是将函数转换为本机向量化版本以获得更好的性能。当与jax.jit组合时,它可以与手动重写函数以处理额外的批处理维度一样高效。

使用vmap进行自动向量化

我们将通过一个简单的例子,使用jax.vmap将矩阵-向量乘积提升为矩阵-矩阵乘积。虽然在这个特定情况下手动执行很容易,但相同的技术可以应用于更复杂的函数:

from jax import random, vmap, jit

key1, key2 = random.split(random.key(1701))
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x)

# 手动批处理
@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

# 自动向量化
@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

正如你所期望的,jax.vmap可以与jax.jit、jax.grad以及任何其他JAX转换任意组合。更多关于JAX中的自动向量化,请查看自动向量化文档

JAX Logo

JAX的功能远不止于此,它为数值计算和机器学习研究带来了巨大的潜力。希望本快速入门指南能帮助你快速上手JAX,并在实际项目中发挥其强大的作用。如果你想深入了解更多JAX的功能和用法,可以参考官方文档示例代码

【免费下载链接】jax Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more 【免费下载链接】jax 项目地址: https://gitcode.com/gh_mirrors/jax/jax

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值