Google JAX中的自动微分技术详解
前言
自动微分(Automatic Differentiation,简称autodiff)是现代机器学习框架的核心技术之一。作为Google JAX的核心功能,自动微分使得梯度计算变得高效而直观。本文将深入探讨JAX中的自动微分机制,帮助读者掌握这一强大工具。
自动微分基础
什么是自动微分
自动微分是一种介于符号微分和数值微分之间的技术,它通过分解计算过程并应用链式法则,能够高效精确地计算导数。与符号微分不同,它不进行表达式展开;与数值微分不同,它没有截断误差。
JAX中的grad函数
JAX提供了jax.grad
函数来实现自动微分。这个函数接受一个标量值函数,返回其梯度函数:
import jax.numpy as jnp
from jax import grad
# 计算tanh函数的导数
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0)) # 输出tanh在x=2处的导数
高阶导数计算
JAX的一个强大特性是可以轻松计算高阶导数,只需多次应用grad函数:
# 定义函数f(x) = x³ + 2x² - 3x + 1
f = lambda x: x**3 + 2*x**2 - 3*x + 1
# 一阶导数
dfdx = grad(f)
# 二阶导数
d2fdx = grad(dfdx)
# 三阶导数
d3fdx = grad(d2fdx)
print(dfdx(1.0)) # 输出f'(1) = 4
print(d2fdx(1.0)) # 输出f''(1) = 10
print(d3fdx(1.0)) # 输出f'''(1) = 6
实际应用:逻辑回归
让我们通过一个逻辑回归的例子来展示自动微分的实际应用。
模型定义
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# 构建数据集
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# 定义损失函数(负对数似然)
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
计算梯度
使用grad
函数计算权重W和偏置b的梯度:
# 初始化参数
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (3,))
b = jax.random.normal(key, ())
# 计算W的梯度
W_grad = grad(loss)(W, b)
# 计算b的梯度
b_grad = grad(loss, argnums=1)(W, b)
# 同时计算W和b的梯度
W_grad, b_grad = grad(loss, (0, 1))(W, b)
高级特性
处理复杂数据结构
得益于JAX的PyTree抽象,我们可以轻松地对复杂数据结构(如字典、列表、元组等)进行微分:
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
params = {'W': W, 'b': b}
gradients = grad(loss2)(params)
同时计算函数值和梯度
使用value_and_grad
可以高效地同时获取函数值和梯度:
from jax import value_and_grad
loss_value, (W_grad, b_grad) = value_and_grad(loss, (0, 1))(W, b)
数值验证
为了验证自动微分结果的正确性,我们可以使用有限差分法进行验证:
eps = 1e-4
# 验证b的梯度
b_grad_numerical = (loss(W, b + eps/2) - loss(W, b - eps/2)) / eps
print('数值微分结果:', b_grad_numerical)
print('自动微分结果:', grad(loss, 1)(W, b))
# JAX还提供了便捷的验证函数
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # 验证到二阶导数
总结
JAX的自动微分系统提供了强大而灵活的功能:
- 使用
grad
函数可以轻松计算一阶和高阶导数 - 支持对复杂数据结构进行微分
- 提供了
value_and_grad
等实用函数提高效率 - 可以方便地进行数值验证
这些特性使得JAX成为研究和实现机器学习算法的理想工具。对于更高级的自动微分主题,如自定义导数规则等,可以参考JAX的高级自动微分文档。
掌握JAX的自动微分技术,将大大提升你开发和优化机器学习模型的效率。希望本文能帮助你更好地理解和使用这一强大功能。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考