JAX
文章平均质量分 81
JAX 文档
csdddn
技术搬运工
展开
专栏收录文章
- 默认排序
- 最新发布
- 最早发布
- 最多阅读
- 最少阅读
-
JAX - 默认数据类型和 X64 标志
JAX 在默认数据类型设置上平衡了科学计算(偏好高精度)和AI研究(侧重速度)两类用户的需求。默认情况下,jax_enable_x64标志为False,此时所有数值操作使用32位类型,且禁止创建64位数组。用户可通过jax.config.update('jax_enable_x64', True)或环境变量JAX_ENABLE_X64=1全局启用64位支持。这种全局设置目前无法局部修改,但JAX团队正在探索更灵活的方案。该设计体现了JAX在精度与性能之间的权衡考量。原创 2025-10-23 02:10:06 · 271 阅读 · 0 评论 -
JAX - Autodidax:从零实现JAX核心
本文介绍了如何从零实现JAX核心系统,重点讲解了变换机制的工作原理。主要内容包括: 通过解释器栈实现变换机制,将变换视为对原始操作的不同解释方式 设计了Primitive、Trace和Tracer等核心类,实现操作拦截和变换处理 解释了bind函数如何根据当前激活的解释器决定应用哪个变换规则 展示了如何通过上下文管理器管理解释器栈,实现变换的组合和嵌套 提供了基础数据类型和运算符的实现,如ShapedArray和基本数学运算 文章还提到这是正在进行中的工作,后续将补充自动微分和vmap等更复杂的功能实现。原创 2025-10-21 03:33:27 · 906 阅读 · 0 评论 -
JAX - GPU 内存分配
JAX默认预分配75%的GPU内存以优化性能,但可能导致内存不足(OOM)错误。可以通过调整环境变量来修改内存分配策略:禁用预分配、调整预分配比例或使用按需分配模式。常见OOM原因包括:运行多个JAX进程、与TensorFlow并行运行、在显示GPU上运行等。实验性功能如cudaMallocAsync可作为替代方案,但存在碎片化和性能不稳定的风险。文章提供了针对不同场景的详细配置建议和解决方案。原创 2025-10-21 03:34:07 · 868 阅读 · 0 评论 -
JAX - 基于 JAX 的开发
本文介绍了JAX在不同场景下的应用方式及其优势。JAX的核心功能包括高效的梯度计算、单核多设备加速以及并行计算能力,这些特性使其广泛应用于机器学习优化、状态空间模型等场景。文章展示了三种集成JAX的方式:直接使用底层API、组合领域专用库(如Flax和Optax搭配),以及完全封装JAX(如PyMC)。这些灵活的使用模式使JAX既适合从零开发,也能无缝融入现有工作流,为开发者提供了高效的计算加速方案。原创 2025-10-21 03:35:06 · 381 阅读 · 0 评论 -
JAX - Rank promotion(秩提升)警告
本文介绍了JAX中NumPy广播规则的秩提升(rank promotion)行为及其潜在风险。秩提升允许数组维度自动扩展,虽然方便但可能掩盖形状错误。JAX提供了三种配置选项:allow(默认允许)、warn(警告)和raise(报错)。用户可通过上下文管理器局部设置,或通过jax.config、环境变量、命令行参数全局配置。这有助于开发者更安全地使用广播特性,避免意外错误。原创 2025-10-21 03:33:42 · 244 阅读 · 0 评论 -
JAX - 常见问题解答(FAQ)
JAX常见问题解答摘要:本文解答了JAX使用中的常见问题,包括jax.jit()的行为变化、数值结果差异、编译速度慢等问题。重点解释了全局状态和副作用对JIT的影响,XLA优化导致的数值差异,以及如何优化编译慢的函数。对于类方法使用JIT,提供了三种策略:辅助函数、标记静态参数和注册PyTree类型。每个问题都附有代码示例和解决方案,帮助用户更好地理解和使用JAX。原创 2025-10-21 03:34:15 · 900 阅读 · 0 评论 -
JAX - 在 JAX 中编写自定义 Jaxpr 解释器
本文介绍了如何在JAX中编写自定义Jaxpr解释器,通过实现一个简单的"逆变换器"示例来展示其工作原理。首先解释了JAX如何通过Jaxpr tracer将Python函数转换为中间表示,然后详细说明了如何构建inverse解释器:追踪函数生成Jaxpr,创建primitive逆运算注册表,最后反向遍历Jaxpr方程并应用逆运算。这种自定义解释器可以与其他JAX变换组合使用,但需要注意该实现使用了JAX内部API,可能会随时变化。文章通过具体代码展示了从函数追踪到最终实现的完整过程,为理解原创 2025-10-21 03:33:51 · 748 阅读 · 0 评论 -
JAX - 并发
JAX对Python并发支持有限:允许在不同线程中调用JAX API(如jit()或grad()),但禁止多线程同时操作trace值。具体而言,虽然可以在多线程中调用使用JAX tracing的函数,但不能在被jit()装饰的函数内部使用线程操作JAX值,否则可能导致难以诊断的错误。这种限制是为了确保JAX tracing过程的线程安全性。原创 2025-10-21 03:33:10 · 173 阅读 · 0 评论 -
JAX - 回归问题调查
摘要:JAX性能回退调查指南 本文介绍了如何调查JAX版本升级后出现的性能回退问题。主要方法是通过暴力测试定位问题版本:首先使用NVIDIA每日构建容器进行日期级定位,然后通过小时级commit检测缩小范围,最后验证具体commit。文章提供了详细的测试脚本示例,包括docker容器操作、版本切换和性能测试流程。这种方法适用于快速复现的场景,能有效减少XLA重新编译次数,同时保持JAX和XLA版本兼容性。对于更复杂的崩溃问题,建议使用二分法定位。原创 2025-10-21 03:34:00 · 625 阅读 · 0 评论 -
JAX - 异步调度
JAX采用异步调度机制隐藏Python开销,允许程序在加速器计算完成前继续执行。当执行操作如jnp.dot(x, x)时,JAX立即返回一个jax.Array未来值对象,实际计算在后台进行。只有需要检查结果(如打印或转换为numpy数组)时才会等待。这种机制提高了效率,但会导致微基准测试不准确,因为测量的只是任务派发时间而非实际计算时间。要准确测量性能,必须使用block_until_ready()或强制结果传回主机。原创 2025-10-21 03:33:35 · 288 阅读 · 0 评论 -
JAX - Python 和 NumPy 版本支持政策
JAX 遵循 Python 科学社区的版本支持政策,但延长了对 Python 的支持周期。具体政策包括:支持发布前45个月内的Python版本(如3.11支持至2026年7月)、24个月内的NumPy版本(如2.0支持至2026年6月)和24个月内的SciPy版本(如1.13支持至2026年4月)。JAX可能支持比政策要求更老的版本,但超过上述期限后可能随时移除旧版本支持。原创 2025-10-21 03:34:26 · 365 阅读 · 0 评论 -
JAX - API兼容性
JAX API兼容性政策摘要:JAX采用基于工作量的版本管理,目前处于零版本阶段,允许小版本引入破坏性变更。公共API(如jax、jax.numpy等)受保障,遵循3个月弃用周期。明确私有API、jaxlib、实验性模块和数值精确值不受保障。伪随机函数输出值可能变更,但分布特性保持稳定。建议用户避免依赖私有API,并关注变更公告。JAX团队正逐步规范化API边界,提高稳定性预期。原创 2025-10-18 00:02:29 · 549 阅读 · 0 评论 -
JAX - 分布式数组与自动并行化
本教程介绍了JAX中通过jax.Array实现的分布式计算功能。主要内容包括: jax.Array作为统一数组对象模型,支持跨设备分布式存储 使用Sharding对象描述数组在设备间的内存布局,通过NamedSharding子类定义分布方式 编译器自动根据输入数据分布进行并行计算,包括元素操作和矩阵乘法等 计算结果会保留输入数据的分布特性,无需额外通信 通过实验展示了分布式计算相比单设备计算的性能优势 该功能需要至少8个设备支持,适合大规模并行计算场景。原创 2025-10-18 00:02:24 · 660 阅读 · 0 评论 -
JAX - 关键概念
JAX 是一个用于高性能数值计算的 Python 库,其核心特性包括:1)转换函数(如 jit、vmap、grad),通过装饰器语法实现即时编译、自动向量化和自动微分;2)追踪机制,通过 Tracer 对象在编译前提取函数操作序列;3)JAXPR 中间表示,用于精确编码函数操作;4)Pytree 数据结构,统一处理数组集合;5)分层 API 设计,包含高级的 jax.numpy 和底层的 jax.lax。这些特性使 JAX 能够高效执行数值计算并支持自动微分和硬件加速。原创 2025-10-18 00:02:16 · 951 阅读 · 0 评论 -
JAX - 类型提升语义
本文详细介绍了JAX的类型提升规则体系,重点分析了其与NumPy的核心差异。JAX采用基于类型格的提升机制,针对现代GPU/TPU加速器优化了数值类型处理,优先保持低精度类型而非默认提升到64位。特别说明了bfloat16等特殊浮点类型的处理方式,以及弱类型值对提升行为的影响。文章还介绍了严格类型检查模式的使用方法,帮助开发者避免隐式类型转换带来的问题。这些设计使JAX更适合高性能数值计算场景,同时保持了与Python生态的良好兼容性。原创 2025-10-17 01:47:34 · 679 阅读 · 0 评论 -
JAX - 并行编程简介
JAX并行编程简介 本教程介绍了JAX中SPMD(单程序多数据)并行计算的三种模式: 自动分片(jax.jit) - 编译器自动优化数据分片和计算策略 显式分片 - 用户指定分片方式作为类型的一部分,约束编译器行为 手动分片(shard_map) - 用户直接编写单设备代码,显式控制数据分片和通信 关键概念是数据分片,通过jax.Array和Sharding对象描述数据在设备间的分布。jax.jit能自动处理分片数据的并行计算,而shard_map允许用户编写针对单设备分片的函数。教程还介绍了如何创建分片数原创 2025-10-17 01:47:40 · 718 阅读 · 0 评论 -
JAX - 参与贡献 JAX
JAX 项目欢迎社区贡献,包括回答问题、改进文档、提交代码等。主要贡献流程:1) 签署Google CLA协议;2) Fork项目并创建开发分支;3) 本地安装测试环境;4) 确保通过lint检查与测试;5) 提交PR。注意事项:PR应为单一变更、通过完整CI测试,维护者会进行额外后端测试。项目遵循Google开源指南,推荐从标记"good first issue"的问题入手。原创 2025-10-16 01:34:41 · 1013 阅读 · 0 评论 -
JAX - 高级自动微分
摘要: 本教程深入探讨JAX中自动微分的高级应用,包括高阶导数计算和优化技巧。主要内容包括:1)使用jacfwd和jacrev计算Hessian矩阵;2)高阶优化如MAML的梯度更新微分;3)利用stop_gradient阻断特定梯度传播,实现TD(0)强化学习;4)直通估计器实现不可微函数的伪梯度;5)通过vmap高效计算样本级梯度。这些功能展示了JAX在自动微分方面的强大灵活性,使其成为复杂机器学习算法实现的理想工具。原创 2025-10-16 01:34:35 · 638 阅读 · 0 评论 -
JAX - 从源码构建
文章摘要:本文详细介绍了从源码构建JAX的完整步骤,包括获取源代码、构建jaxlib支持库以及安装Python包。重点说明了如何通过pip安装预编译的jaxlib或从源码编译,并提供了不同平台(Linux/Windows/MacOS)和硬件(CPU/GPU/TPU)的构建指导。特别强调了使用hermetic Python以确保构建的隔离性和可复现性,同时给出了指定Python版本、管理依赖以及处理XLA仓库修改的方法。对于CUDA、ROCM等特殊环境也提供了专门的构建说明。原创 2025-10-16 01:34:29 · 638 阅读 · 0 评论 -
JAX - Tracing(追踪)
JAX通过追踪机制(tracing)来优化函数执行,将Python函数转换为高效的XLA计算图。追踪过程中,函数参数被替换为仅包含形状和类型的tracer对象,生成jaxpr中间表示。静态值与追踪值的区分是关键,静态操作在编译时执行,而追踪操作在运行时由XLA处理。JAX变换(如jit、grad、vmap)会引入不同类型的tracer,影响控制流和数值计算。正确使用numpy和jax.numpy可以精确控制静态与追踪操作,是高效使用JAX编译的核心技巧。原创 2025-10-16 01:34:01 · 770 阅读 · 0 评论 -
JAX - 有状态计算
本文介绍了在JAX中处理有状态计算的策略。由于JAX要求函数是纯函数,不能有副作用,传统的有状态计算(如计数器、模型参数等)需要转换为显式状态传递的形式。作者通过计数器示例展示了如何将状态作为参数传递和返回,并应用于线性回归模型。这种函数式编程范式是JAX程序处理状态的核心方法,虽然增加了手动管理状态的复杂度,但保证了与JAX变换的兼容性。文章还提到可以通过JAX生态系统库来简化复杂神经网络的状态管理。原创 2025-10-16 01:33:54 · 524 阅读 · 0 评论 -
JAX - JIT下的控制流与逻辑操作
JIT下的控制流与逻辑操作 在JAX中,JIT编译时Python控制流和逻辑操作会受到限制。常规Python控制流无法在JIT环境下直接使用,因为它们依赖于运行时的具体值。JAX提供了结构化控制流原语(如lax.cond、lax.while_loop、lax.fori_loop和lax.scan)来解决这个问题,这些原语可以正确追踪并支持自动微分。对于逻辑操作,JAX提供了jnp.logical_and等函数,但要注意它们不会短路求值。通过static_argnames参数可以指定某些参数在编译时使用具体值原创 2025-10-16 01:33:47 · 709 阅读 · 0 评论 -
JAX - 快速开始:如何用 JAX 思考
JAX 快速入门指南 JAX 是一个高性能数值计算库,结合了 NumPy 风格接口、自动微分和 JIT 编译功能。本文介绍了 JAX 的核心特性: NumPy 兼容性:JAX 提供 jax.numpy 模块,与 NumPy API 高度兼容,但数组不可变。 JIT 编译:通过 jax.jit 装饰器可将函数编译为高效代码,显著提升性能,但要求数组形状静态可知。 自动微分:jax.grad 实现自动微分,可与 JIT 编译组合使用。 设备支持:JAX 数组可在 CPU/GPU/TPU 上运行,支持多设备分片并原创 2025-10-16 01:33:39 · 665 阅读 · 0 评论 -
JAX - 使用 pytrees
JAX中的pytree是一种支持嵌套结构的树形数据容器,如列表、元组和字典的任意组合。它广泛用于机器学习场景,如处理模型参数、数据集等。JAX提供jax.tree.map等实用函数操作pytree,支持对嵌套结构进行批量处理。文章还介绍了如何通过register_pytree_node注册自定义容器类型为pytree节点,使其支持JAX的树形操作。通过多层感知机(MLP)示例展示了pytree在实际模型训练中的应用,包括参数初始化、前向传播和梯度更新等操作。原创 2025-10-16 01:33:29 · 875 阅读 · 0 评论 -
JAX - 伪随机数(Pseudorandom numbers)
JAX与NumPy在伪随机数生成机制上存在显著差异。NumPy基于全局状态,使用梅森旋转法,保证顺序等价性;而JAX采用显式随机key设计,通过Threefry PRNG实现可分裂机制,支持并行化但牺牲了顺序等价性。JAX要求显式管理key的消耗与分裂,避免状态重复使用,确保可复现性和并行性。这种设计虽增加了使用复杂度,但更适合现代硬件加速和大规模并行计算场景。原创 2025-10-16 01:33:00 · 992 阅读 · 0 评论 -
JAX - 自动向量化(Automatic vectorization)
本文介绍了JAX中的自动向量化功能,通过jax.vmap()可以轻松实现函数的批量处理。相比手动编写循环或重写向量化代码,jax.vmap()能自动追踪函数并在输入前添加批量轴。文章展示了如何使用in_axes和out_axes参数控制批量轴位置,以及如何对部分参数进行批量化处理。此外,jax.vmap()还可以与其他JAX变换(如jax.jit())自由组合使用。这种自动向量化方法简化了批量计算,提高了代码效率。原创 2025-10-16 01:32:43 · 382 阅读 · 0 评论 -
JAX - 自动微分(Automatic differentiation)
JAX提供了强大的自动微分功能,支持梯度计算、高阶求导和复杂数据结构求导。通过jax.grad()可以轻松计算函数梯度,支持嵌套列表、元组和字典等PyTree结构。教程展示了在线性逻辑回归中应用自动微分,并介绍了同时计算函数值和梯度的jax.value_and_grad()方法。此外,还演示了如何用有限差分验证自动微分结果。JAX的自动微分系统使高阶导数计算变得简单,并支持PyTree自定义节点,为机器学习等应用提供了灵活高效的梯度计算工具。原创 2025-10-16 01:32:34 · 899 阅读 · 0 评论 -
JAX - 即时编译(Just-in-time compilation)
本文介绍了JAX的即时编译(JIT)机制及其工作原理。主要内容包括: JAX通过将函数转换为原语操作序列(jaxpr)实现转换,但不会捕获副作用,需使用纯函数。 使用jax.jit()可将Python函数编译为XLA高效代码,如SELU函数编译后速度提升4倍以上。 JIT编译的限制:不能直接编译依赖运行时值的条件分支和循环,需重写代码或使用控制流操作符。 对必须依赖输入值的条件,可使用static_argnums/static_argnames标记静态参数,但会为不同值重新编译。 调试技巧:print()在原创 2025-10-16 01:32:29 · 949 阅读 · 0 评论 -
JAX - 安装
JAX安装指南摘要:JAX支持CPU、NVIDIA GPU、Google TPU等多种硬件平台,安装方式因平台而异。CPU用户直接pip安装jax;NVIDIA GPU用户推荐通过"jax[cuda13]"自动安装CUDA;TPU用户使用"jax[tpu]"。Mac GPU(实验性)和AMD GPU(Linux)有特定安装要求。不同操作系统和架构的兼容性差异较大,Windows支持有限,Linux x86_64支持最全面。安装前需确认平台兼容性,并注意CUDA版本与驱原创 2025-10-16 01:32:19 · 994 阅读 · 0 评论 -
JAX 0.6.2 更新日志
JAX 0.6.1版本于2025年5月21日发布,主要更新包括:新增jax.lax.axis_size()函数获取映射轴大小;重新启用CUDA依赖版本检查,并支持夜间构建包安装;PartitionSpec不再继承元组,ShapeDtypeStruct改为不可变对象;弃用custom_jvp_call_jaxpr_p函数,将在0.7.0移除。这些变更涉及核心功能优化和API调整。原创 2025-10-15 07:56:34 · 172 阅读 · 0 评论 -
JAX 0.7.2 更新日志
JAX 0.7.2版本带来多项变更:移除了from_dlpack()对DLPack capsule的支持,最低NumPy版本要求升至2.0,内部改用TypedNdArray表示常量。修复了arr.view()和randint的数值分布问题。弃用了jax2tf.convert的非原生序列化参数,并调整了jax_pmap_no_rank_reduction的默认行为。这些变更可能影响现有代码的兼容性。原创 2025-10-15 07:58:09 · 286 阅读 · 0 评论 -
JAX 0.7.1 更新日志
JAX 0.7.1版本更新要点:新增Python 3.14/3.14t支持,扩展Mac平台自由线程版本;引入jax.set_mesh替代旧API,升级至CUDA 12.9构建;jax.lax.dot()新增通用点积功能;废弃jax.lax.zeros_like_array()等多项API,其中部分内部API未提供替代方案。该版本包含多项功能优化与清理。(142字)原创 2025-10-15 07:57:38 · 256 阅读 · 0 评论 -
JAX 0.7.0 更新日志
JAX 0.7.0版本带来多项重要更新:新增jax.P别名和jax.tree.reduce_associative()功能,改进数组索引参数。重大变更包括默认迁移至Shardy系统、自动微分切换为直接线性化、API重命名(如Layout改为Format)及移除部分废弃功能(如infeed/outfeed)。弃用说明涉及数据类型检查、特殊函数等多项内容,建议用户参考迁移指南调整代码。最低Python版本要求提升至3.11,将支持至2026年7月。原创 2025-10-15 07:57:06 · 381 阅读 · 0 评论 -
JAX 0.6.1 更新日志
JAX 0.6.1版本于2025年5月21日发布,主要更新包括:新增jax.lax.axis_size()函数获取映射轴大小;重新启用CUDA依赖版本检查,并支持夜间构建包安装;PartitionSpec不再继承元组,ShapeDtypeStruct改为不可变对象;弃用custom_jvp_call_jaxpr_p函数,将在0.7.0移除。这些变更涉及核心功能优化和API调整。原创 2025-10-15 07:55:34 · 149 阅读 · 0 评论 -
JAX 0.6.0 更新日志
JAX 0.6.0版本带来多项重要变更:移除了jax.numpy.array()对None参数的支持、废弃了多个配置选项和API(如jax.tree_util.build_tree()改为jax.tree.unflatten()),并调整了CUDA支持要求(最低CuDNN v9.8)。此外,改进了包装器处理逻辑,要求jax.jit()函数必须通过位置参数传递。还更新了包安装命名规范(改用短横线分隔),并废弃了多个内部API。该版本继续推进代码清理和标准化工作,为后续版本做准备。原创 2025-10-15 07:54:46 · 761 阅读 · 0 评论 -
jax 0.5.3 更新日志
JAX 0.5.3版本于2025年3月19日发布,主要包含两项新特性:1) 在jax.lax.dynamic_slice()及相关函数中新增allow_negative_indices选项,可减少代码体积;2) 在jax.random.categorical()中新增replace选项,支持不放回采样功能。该版本保持了与现有行为的兼容性。原创 2025-10-15 07:54:04 · 129 阅读 · 0 评论 -
jax 0.5.1 更新日志
jit 追踪缓存现在会根据输入的 NamedSharding 类型进行缓存键处理。此前,追踪缓存并未包含分片信息(虽然后续的 jit 缓存如 lowering 和 compilation 缓存包含了分片信息),因此两种不同类型但等价的分片不会重新追踪,现在则会。原文地址:https://docs.jax.dev/en/latest/changelog.html#jax-0-5-1-feb-24-2025。时,会因为分片类型改变而导致追踪缓存失效。在上述例子中,先调用。原创 2025-10-15 07:52:28 · 497 阅读 · 0 评论 -
jax 0.5.0 更新日志
JAX 0.5.0版本引入重大变更:采用工作量版本控制,默认启用PRNG key新语义,停止支持Mac x86平台。主要更新包括:提升NumPy/SciPy最低版本要求,优化einsum和solve函数,扩展FFT支持高维变换,新增FFI自定义状态支持。同时弃用部分API并移除已过时功能。该版本需要用户检查代码兼容性,特别是PRNG相关逻辑。原创 2025-10-15 07:51:43 · 308 阅读 · 0 评论
分享