【深度学习框架】JAX:高效的数值计算与深度学习框架

部署运行你感兴趣的模型镜像

1. 什么是 JAX?

JAX 是由 Google Research 开发的 高性能数值计算库,主要用于 机器学习、深度学习科学计算。它基于 NumPy 的 API,但提供了 自动微分(Autograd)XLA 编译加速高效的 GPU/TPU 计算,使其成为 TensorFlow 和 PyTorch 的强劲竞争者


2. JAX 的核心特点

1️⃣ 自动微分(Autograd)

  • JAX 提供 前向(Forward-mode)反向(Reverse-mode) 自动微分,适用于各种梯度计算任务,如 深度学习、强化学习、物理模拟 等。

2️⃣ JIT 编译(加速计算)

  • JAX 使用 XLA(Accelerated Linear Algebra) 进行 Just-In-Time(JIT)编译,大幅提升计算速度,类似 TensorFlow 的 Graph Execution。

3️⃣ 并行计算(Vectorization & GPU/TPU 加速)

  • vmap(自动向量化):自动将标量操作向量化,提高效率。
  • pmap(多 GPU/TPU 并行化):轻松实现数据并行计算。

4️⃣ 兼容 NumPy

  • JAX 的 API 设计 类似 NumPy,但所有计算都是 不可变(Immutable) 的,适合 函数式编程

3. 安装 JAX

JAX 可以通过 pip 安装:

# 仅支持 CPU
pip install jax

# 支持 GPU(CUDA 版本)
pip install jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

注意:如果使用 GPU,需要安装 正确版本的 CUDA 和 cuDNN


4. JAX 基本用法

1️⃣ JAX 作为 NumPy 替代

JAX 提供 jax.numpy,它的 API 近似于 NumPy,但支持 GPU/TPU 加速:

import jax.numpy as jnp

# 创建 JAX 数组(不可变)
x = jnp.array([1.0, 2.0, 3.0])

# 矩阵运算(在 GPU 上计算)
y = jnp.dot(x, x)
print(y)  # 14.0


2️⃣ 自动微分(grad)

JAX 提供 jax.grad() 计算标量函数的梯度:

import jax

# 定义函数:y = x^2
def f(x):
    return x ** 2

# 计算导数 dy/dx
grad_f = jax.grad(f)

# 在 x=3 处求导
print(grad_f(3.0))  # 6.0

⚠️ 重要说明
  • grad(f) 只能用于标量输出(如损失函数)。
  • 如果是 向量输出,可以使用 jax.jacfwd()jax.jacrev() 计算 雅可比矩阵

3️⃣ JIT 编译(加速计算)

使用 jax.jit() 可以 即时编译 代码,提高计算速度:

import jax.numpy as jnp
from jax import jit

# 定义函数
def slow_func(x):
    return jnp.sin(x) + jnp.cos(x)

# JIT 编译加速
fast_func = jit(slow_func)

# 测试计算
import time

x = jnp.linspace(0, 10, 1000)

# 普通计算
start = time.time()
slow_func(x).block_until_ready()  # 确保计算完成
print("Normal:", time.time() - start)

# JIT 编译后计算
start = time.time()
fast_func(x).block_until_ready()
print("JIT Compiled:", time.time() - start)

 运行结果

Normal: 0.04390454292297363
JIT Compiled: 0.01696181297302246


4️⃣ 向量化计算(vmap)

使用 jax.vmap() 自动向量化 计算,提高批量处理效率:

from jax import vmap
import jax.numpy as jnp

# 定义标量函数
def f(x):
    return x ** 2

# 直接计算(需要手写 for 循环)
x = jnp.array([1.0, 2.0, 3.0])
print(f(x))  # 错误!f 只能处理标量

# 使用 vmap 自动向量化
vectorized_f = vmap(f)
print(vectorized_f(x))  # [1.0, 4.0, 9.0]

运行结果 

[1. 4. 9.]
[1. 4. 9.]


5️⃣ 并行计算(pmap)

使用 jax.pmap() 进行 多 GPU/TPU 并行计算

from jax import pmap
import jax.numpy as jnp

# 定义一个简单的计算函数
def f(x):
    return x ** 2 + 2 * x + 1

# 在多个设备上并行计算
parallel_f = pmap(f)

# 输入数据(假设有 2 张 GPU)
x = jnp.array([1.0, 2.0])
print(parallel_f(x))  # [4.0, 9.0]


5. 使用 JAX 训练深度学习模型

使用 JAX 训练一个简单的 逻辑回归模型

import jax.numpy as jnp
import jax
from jax import grad, jit
from jax.scipy.special import expit as sigmoid

# 生成随机数据
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 2))  # 100 个样本,每个 2 维
y = (X[:, 0] + X[:, 1] > 0).astype(jnp.float32)  # 线性分类任务

# 初始化权重
w = jax.random.normal(key, (2,))
b = 0.0

# 定义损失函数(交叉熵)
def loss_fn(w, b, X, y):
    logits = jnp.dot(X, w) + b
    return -jnp.mean(y * jnp.log(sigmoid(logits)) + (1 - y) * jnp.log(1 - sigmoid(logits)))

# 计算梯度
grad_fn = grad(loss_fn)

# 训练循环
lr = 0.1
for i in range(100):
    grads = grad_fn(w, b, X, y)
    w -= lr * grads[0]
    b -= lr * grads[1]

print("训练完成:", w, b)

 运行结果

训练完成: [1.4514779 3.0926886] -0.20143564


6. JAX vs. PyTorch vs. TensorFlow

特性JAXPyTorchTensorFlow
计算方式函数式命令式符号式+命令式
GPU/TPU✅ 强大✅ 强大✅ 强大
自动微分✅ 强大(Autograd)✅(torch.autograd)✅(tf.GradientTape)
JIT 编译✅(XLA)✅(XLA)
并行计算pmap❌ 需要 DDPtf.distribute
适用场景数学、优化、强化学习、深度学习深度学习、CV、NLP大规模 AI 训练

7. 总结

  • JAX 适合需要高性能计算的 AI 研究,尤其是 强化学习、物理模拟、自动微分优化 等任务。
  • JIT 编译、自动微分和 GPU/TPU 并行化 让 JAX 比 NumPy、PyTorch 更高效。
  • JAX 代码风格简洁,与 NumPy 兼容,但 学习曲线比 PyTorch/TensorFlow 稍陡峭

JAX 是一个未来 AI 计算的重要工具,适用于高效数值计算和深度学习,尤其适合 Google Cloud、TPU 和科学计算 领域!

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值