Google JAX中的自动微分技术详解

Google JAX中的自动微分技术详解

jax Python+NumPy程序的可组合变换功能:进行求导、矢量化、JIT编译至GPU/TPU及其他更多操作 jax 项目地址: https://gitcode.com/gh_mirrors/ja/jax

前言

自动微分(Automatic Differentiation,简称autodiff)是现代机器学习框架的核心技术之一。Google JAX作为一个高性能数值计算库,其自动微分系统设计精巧且功能强大。本文将深入解析JAX中的自动微分机制,帮助读者掌握这一关键技术。

自动微分基础

自动微分与符号微分和数值微分不同,它通过分解计算过程,应用链式法则来自动计算导数。JAX采用反向模式自动微分(反向传播),特别适合处理多输入、单输出的函数梯度计算。

基本梯度计算

JAX中最基础的自动微分函数是jax.grad,它可以计算标量值函数的梯度:

import jax
import jax.numpy as jnp

# 计算tanh函数的导数
grad_tanh = jax.grad(jnp.tanh)
print(grad_tanh(2.0))  # 输出tanh在x=2.0处的导数值

jax.grad接受一个函数并返回其梯度函数。如果f是一个Python函数,那么jax.grad(f)就是计算∇f的函数。

高阶导数计算

JAX的一个强大特性是可以轻松计算高阶导数,只需连续应用jax.grad

# 定义函数f(x) = x³ + 2x² - 3x + 1
f = lambda x: x**3 + 2*x**2 - 3*x + 1

# 计算各阶导数
dfdx = jax.grad(f)        # 一阶导数
d2fdx = jax.grad(dfdx)    # 二阶导数
d3fdx = jax.grad(d2fdx)   # 三阶导数
d4fdx = jax.grad(d3fdx)   # 四阶导数

# 在x=1处评估各阶导数
print(dfdx(1.0))   # 输出4.0
print(d2fdx(1.0))  # 输出10.0
print(d3fdx(1.0))  # 输出6.0
print(d4fdx(1.0))  # 输出0.0

实际应用:线性逻辑回归

让我们通过一个线性逻辑回归的例子,展示JAX自动微分在实际问题中的应用。

模型定义

首先定义模型和损失函数:

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# 定义损失函数(负对数似然)
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

参数梯度计算

使用jax.grad计算损失函数对参数的梯度:

# 计算损失函数对W的梯度
W_grad = jax.grad(loss)(W, b)

# 计算损失函数对b的梯度
b_grad = jax.grad(loss, argnums=1)(W, b)

# 同时计算对W和b的梯度
W_grad, b_grad = jax.grad(loss, (0, 1))(W, b)

argnums参数允许我们指定要对哪些参数求导,非常灵活。

处理复杂数据结构

JAX的PyTree机制使其能够自然地处理嵌套的Python数据结构,如列表、元组和字典。

字典参数示例

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# 计算对字典参数的梯度
gradients = jax.grad(loss2)({'W': W, 'b': b})

这种设计使得代码组织更加灵活,可以轻松处理复杂的模型参数结构。

高效计算:同时获取函数值和梯度

在实际优化过程中,我们通常需要同时计算函数值和梯度。JAX提供了jax.value_and_grad函数来高效地完成这一任务:

loss_value, gradients = jax.value_and_grad(loss)(W, b)

这种方式避免了重复计算,显著提高了效率。

数值验证

为了确保自动微分计算的正确性,我们可以使用有限差分法进行验证:

# 设置微小步长
eps = 1e-4

# 对偏置b进行数值梯度检验
b_grad_numerical = (loss(W, b + eps/2) - loss(W, b - eps/2)) / eps

# 对权重W进行数值梯度检验(随机方向)
vec = jax.random.normal(key, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps/2*unitvec, b) - loss(W - eps/2*unitvec, b)) / eps

JAX还提供了便捷的梯度检查函数:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # 检查至二阶导数

进阶主题

虽然本文涵盖了JAX自动微分的基础用法,但JAX还支持更多高级特性:

  1. 自定义导数规则:可以为原生Python函数定义自定义的导数规则
  2. 前向模式自动微分:适合少输入多输出的情况
  3. Jacobian和Hessian计算:使用jax.jacfwdjax.jacrev等函数
  4. 隐函数微分:处理隐式定义的函数

这些高级主题将在后续文章中详细介绍。

总结

JAX的自动微分系统设计精良,具有以下特点:

  1. 简单易用:通过jax.grad等函数提供直观的接口
  2. 灵活强大:支持高阶导数、复杂数据结构和自定义规则
  3. 高效可靠:计算结果精确,性能优异

掌握JAX的自动微分技术,将大大提升您在机器学习、科学计算等领域的开发效率和模型性能。

jax Python+NumPy程序的可组合变换功能:进行求导、矢量化、JIT编译至GPU/TPU及其他更多操作 jax 项目地址: https://gitcode.com/gh_mirrors/ja/jax

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

卓滨威Delmar

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值