Google JAX中的自动微分技术详解

Google JAX中的自动微分技术详解

jax Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more jax 项目地址: https://gitcode.com/gh_mirrors/jax/jax

前言

自动微分(Automatic Differentiation,简称autodiff)是现代机器学习框架的核心技术之一。作为Google JAX的核心功能,自动微分使得梯度计算变得高效而直观。本文将深入探讨JAX中的自动微分机制,帮助读者掌握这一强大工具。

自动微分基础

什么是自动微分

自动微分是一种介于符号微分和数值微分之间的技术,它通过分解计算过程并应用链式法则,能够高效精确地计算导数。与符号微分不同,它不进行表达式展开;与数值微分不同,它没有截断误差。

JAX中的grad函数

JAX提供了jax.grad函数来实现自动微分。这个函数接受一个标量值函数,返回其梯度函数:

import jax.numpy as jnp
from jax import grad

# 计算tanh函数的导数
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))  # 输出tanh在x=2处的导数

高阶导数计算

JAX的一个强大特性是可以轻松计算高阶导数,只需多次应用grad函数:

# 定义函数f(x) = x³ + 2x² - 3x + 1
f = lambda x: x**3 + 2*x**2 - 3*x + 1

# 一阶导数
dfdx = grad(f)
# 二阶导数
d2fdx = grad(dfdx)
# 三阶导数
d3fdx = grad(d2fdx)

print(dfdx(1.0))   # 输出f'(1) = 4
print(d2fdx(1.0))  # 输出f''(1) = 10
print(d3fdx(1.0))  # 输出f'''(1) = 6

实际应用:逻辑回归

让我们通过一个逻辑回归的例子来展示自动微分的实际应用。

模型定义

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# 构建数据集
inputs = jnp.array([[0.52, 1.12, 0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# 定义损失函数(负对数似然)
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

计算梯度

使用grad函数计算权重W和偏置b的梯度:

# 初始化参数
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (3,))
b = jax.random.normal(key, ())

# 计算W的梯度
W_grad = grad(loss)(W, b)

# 计算b的梯度
b_grad = grad(loss, argnums=1)(W, b)

# 同时计算W和b的梯度
W_grad, b_grad = grad(loss, (0, 1))(W, b)

高级特性

处理复杂数据结构

得益于JAX的PyTree抽象,我们可以轻松地对复杂数据结构(如字典、列表、元组等)进行微分:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

params = {'W': W, 'b': b}
gradients = grad(loss2)(params)

同时计算函数值和梯度

使用value_and_grad可以高效地同时获取函数值和梯度:

from jax import value_and_grad

loss_value, (W_grad, b_grad) = value_and_grad(loss, (0, 1))(W, b)

数值验证

为了验证自动微分结果的正确性,我们可以使用有限差分法进行验证:

eps = 1e-4

# 验证b的梯度
b_grad_numerical = (loss(W, b + eps/2) - loss(W, b - eps/2)) / eps
print('数值微分结果:', b_grad_numerical)
print('自动微分结果:', grad(loss, 1)(W, b))

# JAX还提供了便捷的验证函数
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # 验证到二阶导数

总结

JAX的自动微分系统提供了强大而灵活的功能:

  1. 使用grad函数可以轻松计算一阶和高阶导数
  2. 支持对复杂数据结构进行微分
  3. 提供了value_and_grad等实用函数提高效率
  4. 可以方便地进行数值验证

这些特性使得JAX成为研究和实现机器学习算法的理想工具。对于更高级的自动微分主题,如自定义导数规则等,可以参考JAX的高级自动微分文档。

掌握JAX的自动微分技术,将大大提升你开发和优化机器学习模型的效率。希望本文能帮助你更好地理解和使用这一强大功能。

jax Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more jax 项目地址: https://gitcode.com/gh_mirrors/jax/jax

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

邹岩讳Sally

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值