Google JAX中的自动微分技术详解
jax Python+NumPy程序的可组合变换功能:进行求导、矢量化、JIT编译至GPU/TPU及其他更多操作 项目地址: https://gitcode.com/gh_mirrors/ja/jax
前言
自动微分(Automatic Differentiation,简称autodiff)是现代机器学习框架的核心技术之一。Google JAX作为一个高性能数值计算库,其自动微分系统设计精巧且功能强大。本文将深入解析JAX中的自动微分机制,帮助读者掌握这一关键技术。
自动微分基础
自动微分与符号微分和数值微分不同,它通过分解计算过程,应用链式法则来自动计算导数。JAX采用反向模式自动微分(反向传播),特别适合处理多输入、单输出的函数梯度计算。
基本梯度计算
JAX中最基础的自动微分函数是jax.grad
,它可以计算标量值函数的梯度:
import jax
import jax.numpy as jnp
# 计算tanh函数的导数
grad_tanh = jax.grad(jnp.tanh)
print(grad_tanh(2.0)) # 输出tanh在x=2.0处的导数值
jax.grad
接受一个函数并返回其梯度函数。如果f
是一个Python函数,那么jax.grad(f)
就是计算∇f的函数。
高阶导数计算
JAX的一个强大特性是可以轻松计算高阶导数,只需连续应用jax.grad
:
# 定义函数f(x) = x³ + 2x² - 3x + 1
f = lambda x: x**3 + 2*x**2 - 3*x + 1
# 计算各阶导数
dfdx = jax.grad(f) # 一阶导数
d2fdx = jax.grad(dfdx) # 二阶导数
d3fdx = jax.grad(d2fdx) # 三阶导数
d4fdx = jax.grad(d3fdx) # 四阶导数
# 在x=1处评估各阶导数
print(dfdx(1.0)) # 输出4.0
print(d2fdx(1.0)) # 输出10.0
print(d3fdx(1.0)) # 输出6.0
print(d4fdx(1.0)) # 输出0.0
实际应用:线性逻辑回归
让我们通过一个线性逻辑回归的例子,展示JAX自动微分在实际问题中的应用。
模型定义
首先定义模型和损失函数:
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# 定义损失函数(负对数似然)
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
参数梯度计算
使用jax.grad
计算损失函数对参数的梯度:
# 计算损失函数对W的梯度
W_grad = jax.grad(loss)(W, b)
# 计算损失函数对b的梯度
b_grad = jax.grad(loss, argnums=1)(W, b)
# 同时计算对W和b的梯度
W_grad, b_grad = jax.grad(loss, (0, 1))(W, b)
argnums
参数允许我们指定要对哪些参数求导,非常灵活。
处理复杂数据结构
JAX的PyTree机制使其能够自然地处理嵌套的Python数据结构,如列表、元组和字典。
字典参数示例
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# 计算对字典参数的梯度
gradients = jax.grad(loss2)({'W': W, 'b': b})
这种设计使得代码组织更加灵活,可以轻松处理复杂的模型参数结构。
高效计算:同时获取函数值和梯度
在实际优化过程中,我们通常需要同时计算函数值和梯度。JAX提供了jax.value_and_grad
函数来高效地完成这一任务:
loss_value, gradients = jax.value_and_grad(loss)(W, b)
这种方式避免了重复计算,显著提高了效率。
数值验证
为了确保自动微分计算的正确性,我们可以使用有限差分法进行验证:
# 设置微小步长
eps = 1e-4
# 对偏置b进行数值梯度检验
b_grad_numerical = (loss(W, b + eps/2) - loss(W, b - eps/2)) / eps
# 对权重W进行数值梯度检验(随机方向)
vec = jax.random.normal(key, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps/2*unitvec, b) - loss(W - eps/2*unitvec, b)) / eps
JAX还提供了便捷的梯度检查函数:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # 检查至二阶导数
进阶主题
虽然本文涵盖了JAX自动微分的基础用法,但JAX还支持更多高级特性:
- 自定义导数规则:可以为原生Python函数定义自定义的导数规则
- 前向模式自动微分:适合少输入多输出的情况
- Jacobian和Hessian计算:使用
jax.jacfwd
和jax.jacrev
等函数 - 隐函数微分:处理隐式定义的函数
这些高级主题将在后续文章中详细介绍。
总结
JAX的自动微分系统设计精良,具有以下特点:
- 简单易用:通过
jax.grad
等函数提供直观的接口 - 灵活强大:支持高阶导数、复杂数据结构和自定义规则
- 高效可靠:计算结果精确,性能优异
掌握JAX的自动微分技术,将大大提升您在机器学习、科学计算等领域的开发效率和模型性能。
jax Python+NumPy程序的可组合变换功能:进行求导、矢量化、JIT编译至GPU/TPU及其他更多操作 项目地址: https://gitcode.com/gh_mirrors/ja/jax
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考