JAX-中文文档-一-

JAX 中文文档(一)

原文:jax.readthedocs.io/en/latest/

开始入门

安装 JAX

原文:jax.readthedocs.io/en/latest/installation.html

使用 JAX 需要安装两个包:jax 是纯 Python 的跨平台库,jaxlib 包含编译的二进制文件,对于不同的操作系统和加速器需要不同的构建。

TL;DR 对于大多数用户来说,典型的 JAX 安装可能如下所示:

  • 仅限 CPU(Linux/macOS/Windows)

    pip install -U jax 
    
  • GPU(NVIDIA,CUDA 12)

    pip install -U "jax[cuda12]" 
    
  • TPU(Google Cloud TPU VM)

    pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 
    

支持的平台

下表显示了所有支持的平台和安装选项。检查您的设置是否受支持;如果显示“是”或“实验性”,请单击相应链接以了解更详细的 JAX 安装方法。

Linux,x86_64 Linux,aarch64 macOS,Intel x86_64,AMD GPU macOS,Apple Silicon,基于 ARM Windows,x86_64 Windows WSL2,x86_64
CPU
NVIDIA GPU 不适用 实验性
Google Cloud TPU 不适用 不适用 不适用 不适用 不适用
AMD GPU 实验性 不适用

| Apple GPU | 不适用 | 否 | 实验性 | 实验性 | 不适用 | 不适用 | ## CPU

pip 安装:CPU

目前,JAX 团队为以下操作系统和架构发布 jaxlib 轮子:

  • Linux,x86_64

  • Linux, aarch64

  • macOS,Intel

  • macOS,基于 Apple ARM

  • Windows,x86_64(实验性

要安装仅 CPU 版本的 JAX,可能对于在笔记本电脑上进行本地开发非常有用,您可以运行:

pip  install  --upgrade  pip
pip  install  --upgrade  jax 

在 Windows 上,如果尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。

其他操作系统和架构需要从源代码构建。在其他操作系统和架构上尝试 pip 安装可能导致 jaxlib 未能与 jax 一起安装(虽然 jax 可能成功安装,但在运行时可能会失败)。 ## NVIDIA GPU

JAX 支持具有 SM 版本 5.2(Maxwell)或更新版本的 NVIDIA GPU。请注意,由于 NVIDIA 在其软件中停止了对 Kepler 系列 GPU 的支持,JAX 不再支持 Kepler 系列 GPU。

您必须先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但驱动版本必须 >= 525.60.13 才能在 Linux 上运行 CUDA 12。

如果您需要在较老的驱动程序上使用更新的 CUDA 工具包,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 专门为此目的提供的 CUDA 向前兼容包

pip 安装:NVIDIA GPU(通过 pip 安装,更加简便)

有两种安装 JAX 并支持 NVIDIA GPU 的方式:

  • 使用从 pip 轮子安装的 NVIDIA CUDA 和 cuDNN

  • 使用自行安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheel 安装 CUDA 和 cuDNN,因为这样更加简单!

NVIDIA 仅为 x86_64 和 aarch64 平台发布了 CUDA pip 包;在其他平台上,您必须使用本地安装的 CUDA。

pip  install  --upgrade  pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip  install  --upgrade  "jax[cuda12]" 

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,您需要检查以下几点:

  • 请确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 可能会覆盖 NVIDIA CUDA 库。

  • 确保安装的 NVIDIA CUDA 库与 JAX 请求的库相符。重新运行上述安装命令应该可以解决问题。

pip 安装:NVIDIA GPU(本地安装的 CUDA,更为复杂)

如果您想使用预安装的 NVIDIA CUDA 副本,您必须首先安装 NVIDIA 的 CUDA cuDNN

JAX 仅为 Linux x86_64 和 Linux aarch64 提供预编译的 CUDA 兼容 wheel。其他操作系统和架构的组合也可能存在,但需要从源代码构建(请参考构建指南以了解更多信息)。

您应该使用至少与您的NVIDIA CUDA toolkit 对应的驱动版本相同的 NVIDIA 驱动程序版本。例如,在无法轻易更新 NVIDIA 驱动程序的集群上需要使用更新的 CUDA 工具包,您可以使用 NVIDIA 为此目的提供的CUDA 向前兼容包

JAX 目前提供一种 CUDA wheel 变体:

Built with Compatible with
CUDA 12.3 CUDA >=12.1
CUDNN 9.0 CUDNN >=9.0, <10.0
NCCL 2.19 NCCL >=2.18

JAX 检查您的库的版本,如果版本不够新,则会报错。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用此检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅在执行多 GPU 计算时才需要。

安装方法如下:

pip  install  --upgrade  pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip  install  --upgrade  "jax[cuda12_local]" 

这些 pip 安装在 Windows 上无法工作,并可能静默失败;请参考上表。

您可以使用以下命令查找您的 CUDA 版本:

nvcc  --version 

JAX 使用 LD_LIBRARY_PATH 查找 CUDA 库,并使用 PATH 查找二进制文件(ptxasnvlink)。请确保这些路径指向正确的 CUDA 安装位置。

如果在使用预编译的 wheel 时遇到任何错误或问题,请在GitHub 问题跟踪器上告知 JAX 团队。

NVIDIA GPU Docker 容器

NVIDIA 提供了JAX 工具箱容器,这些是 bleeding edge 容器,包含 jax 的夜间版本和一些模型/框架。 ## Google Cloud TPU

pip 安装:Google Cloud TPU

JAX 为 Google Cloud TPU 提供预构建的安装包。要在云 TPU VM 中安装 JAX 及相应版本的 jaxliblibtpu,您可以运行以下命令:

pip  install  jax[tpu]  -f  https://storage.googleapis.com/jax-releases/libtpu_releases.html 

对于 Colab 的用户(https://colab.research.google.com/),请确保您使用的是 TPU v2 而不是已过时的旧 TPU 运行时。## Apple Silicon GPU(基于 ARM 的)

pip 安装:Apple 基于 ARM 的 Silicon GPU

Apple 为基于 ARM 的 GPU 硬件提供了一个实验性的 Metal 插件。详情请参阅 Apple 的 JAX on Metal 文档

注意: Metal 插件存在一些注意事项:

  • Metal 插件是新的实验性质,并存在一些已知问题,请在 JAX 问题跟踪器上报告任何问题。

  • 当前的 Metal 插件需要非常特定版本的 jaxjaxlib。随着插件 API 的成熟,此限制将逐步放宽。## AMD GPU

JAX 具有实验性的 ROCm 支持。有两种安装 JAX 的方法:

  • 使用 AMD 的 Docker 容器;或者

  • 从源代码构建(参见从源代码构建 —— 一个名为 Additional notes for building a ROCM jaxlib for AMD GPUs 的部分)。

Conda(社区支持)

Conda 安装

存在一个社区支持的 jax 的 Conda 构建。要使用 conda 安装它,只需运行:

conda  install  jax  -c  conda-forge 

要在带有 NVIDIA GPU 的机器上安装它,请运行:

conda  install  jaxlib=*=*cuda*  jax  cuda-nvcc  -c  conda-forge  -c  nvidia 

请注意,由 conda-forge 分发的 cudatoolkit 缺少 JAX 所需的 ptxas。因此,您必须从 nvidia 渠道安装 cuda-nvcc 包,或者在您的机器上单独安装 CUDA,以便 ptxas 在您的路径中可用。上述渠道顺序很重要(conda-forgenvidia 之前)。

如果您希望覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 版本,请按照 conda-forge 网站上“技巧和技巧”部分的说明操作。

前往 conda-forgejaxlibjax 存储库获取更多详细信息。

JAX 夜间安装

夜间版本反映了它们构建时主 JAX 存储库的状态,并且可能无法通过完整的测试套件。

  • 仅限 CPU:
pip  install  -U  --pre  jax  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html 
  • Google Cloud TPU:
pip  install  -U  --pre  jax[tpu]  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html  -f  https://storage.googleapis.com/jax-releases/libtpu_releases.html 
  • NVIDIA GPU(CUDA 12):
pip  install  -U  --pre  jax[cuda12]  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html 
  • NVIDIA GPU(CUDA 12)遗留:

用于历史 nightly 版本的单片 CUDA jaxlibs。您很可能不需要此选项;不会再构建更多的单片 CUDA jaxlibs,并且现有的将在 2024 年 9 月到期。请使用上面的“CUDA 12”选项。

pip  install  -U  --pre  jaxlib  -f  https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html 

从源代码构建 JAX

参考从源代码构建。

安装旧版本的 jaxlib wheels

由于 Python 软件包索引上的存储限制,JAX 团队定期从 http://pypi.org/project/jax 的发布中删除旧的jaxlib安装包。但是您仍然可以通过这里的 URL 直接安装它们。例如:

# Install jaxlib on CPU via the wheel archive
pip  install  jax[cpu]==0.3.25  -f  https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip  install  jaxlib==0.3.25  -f  https://storage.googleapis.com/jax-releases/jax_releases.html 

对于特定的旧 GPU 安装包,请确保使用jax_cuda_releases.html的 URL;例如

pip  install  jaxlib==0.3.25+cuda11.cudnn82  -f  https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 

快速入门

原文:jax.readthedocs.io/en/latest/quickstart.html

JAX 是一个面向数组的数值计算库(à la NumPy),具有自动微分和 JIT 编译功能,以支持高性能的机器学习研究

本文档提供了 JAX 主要功能的快速概述,让您可以快速开始使用 JAX:

  • JAX 提供了一个统一的类似于 NumPy 的接口,用于在 CPU、GPU 或 TPU 上运行的计算,在本地或分布式设置中。

  • JAX 通过 Open XLA 内置了即时编译(JIT)功能,这是一个开源的机器学习编译器生态系统。

  • JAX 函数支持通过其自动微分转换有效地评估梯度。

  • JAX 函数可以自动向量化,以有效地将它们映射到表示输入批次的数组上。

安装

可以直接从 Python Package Index 安装 JAX 用于 Linux、Windows 和 macOS 上的 CPU:

pip install jax 

或者,对于 NVIDIA GPU:

pip install -U "jax[cuda12]" 

如需更详细的特定平台安装信息,请查看安装 JAX。

JAX 就像 NumPy 一样

大多数 JAX 的使用是通过熟悉的 jax.numpy API 进行的,通常在 jnp 别名下导入:

import jax.numpy as jnp 

通过这个导入,您可以立即像使用典型的 NumPy 程序一样使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和操作符,以及数组属性和方法:

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x)) 
[0\.        1.05      2.1       3.1499999 4.2      ] 

一旦您开始深入研究,您会发现 JAX 数组和 NumPy 数组之间存在一些差异;这些差异在 🔪 JAX - The Sharp Bits 🔪 中进行了探讨。

使用jax.jit()进行即时编译

JAX 可以在 GPU 或 TPU 上透明运行(如果没有,则退回到 CPU)。然而,在上述示例中,JAX 是一次将核心分派到芯片上的操作。如果我们有一系列操作,我们可以使用 jax.jit() 函数将这些操作一起编译为 XLA。

我们可以使用 IPython 的 %timeit 快速测试我们的 selu 函数,使用 block_until_ready() 来考虑 JAX 的动态分派(请参阅异步分派):

from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready() 
2.84 ms ± 9.23 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 

(请注意,我们已经使用 jax.random 生成了一些随机数;有关如何在 JAX 中生成随机数的详细信息,请查看伪随机数)。

我们可以使用 jax.jit() 转换来加速此函数的执行,该转换将在首次调用 selu 时进行 JIT 编译,并在此后进行缓存。

from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready() 
844 μs ± 2.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

上述时间表示在 CPU 上执行,但同样的代码可以在 GPU 或 TPU 上运行,通常会有更大的加速效果。

欲了解更多关于 JAX 中 JIT 编译的信息,请查看即时编译。

使用 jax.grad() 计算导数

除了通过 JIT 编译转换函数外,JAX 还提供其他转换功能。其中一种转换是 jax.grad(),它执行自动微分 (autodiff)

from jax import grad

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small)) 
[0.25       0.19661197 0.10499357] 

让我们用有限差分来验证我们的结果是否正确。

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small)) 
[0.24998187 0.1965761  0.10502338] 

grad()jit() 转换可以任意组合并混合使用。在上面的示例中,我们对 sum_logistic 进行了 JIT 编译,然后取了它的导数。我们可以进一步进行:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) 
-0.0353256 

除了标量值函数外,jax.jacobian() 转换还可用于计算向量值函数的完整雅可比矩阵:

from jax import jacobian
print(jacobian(jnp.exp)(x_small)) 
[[1\.        0\.        0\.       ]
 [0\.        2.7182817 0\.       ]
 [0\.        0\.        7.389056 ]] 

对于更高级的自动微分操作,您可以使用 jax.vjp() 来进行反向模式向量-雅可比积分,以及使用 jax.jvp()jax.linearize() 进行正向模式雅可比-向量积分。这两者可以任意组合,也可以与其他 JAX 转换组合使用。例如,jax.jvp()jax.vjp() 用于定义正向模式 jax.jacfwd() 和反向模式 jax.jacrev(),用于计算正向和反向模式下的雅可比矩阵。以下是组合它们以有效计算完整 Hessian 矩阵的一种方法:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small)) 
[[-0\.         -0\.         -0\.        ]
 [-0\.         -0.09085776 -0\.        ]
 [-0\.         -0\.         -0.07996249]] 

这种组合在实践中产生了高效的代码;这基本上是 JAX 内置的 jax.hessian() 函数的实现方式。

想了解更多关于 JAX 中的自动微分,请查看自动微分。

使用 jax.vmap() 进行自动向量化

另一个有用的转换是 vmap(),即向量化映射。它具有沿数组轴映射函数的熟悉语义,但与显式循环函数调用不同,它将函数转换为本地向量化版本,以获得更好的性能。与 jit() 组合时,它可以与手动重写函数以处理额外批处理维度的性能相媲美。

我们将处理一个简单的示例,并使用 vmap() 将矩阵-向量乘法提升为矩阵-矩阵乘法。虽然在这种特定情况下手动完成这一点很容易,但相同的技术也适用于更复杂的函数。

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  return jnp.dot(mat, x) 

apply_matrix 函数将一个向量映射到另一个向量,但我们可能希望将其逐行应用于矩阵。在 Python 中,我们可以通过循环遍历批处理维度来实现这一点,但通常导致性能不佳。

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready() 
Naively batched
962 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

熟悉 jnp.dot 函数的程序员可能会意识到,可以重写 apply_matrix 来避免显式循环,利用 jnp.dot 的内置批处理语义:

import numpy as np

@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready() 
Manually batched
14.3 μs ± 28.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each) 

然而,随着函数变得更加复杂,这种手动批处理变得更加困难且容易出错。vmap() 转换旨在自动将函数转换为支持批处理的版本:

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready() 
Auto-vectorized with vmap
21.7 μs ± 98.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) 

正如您所预期的那样,vmap() 可以与 jit()grad() 和任何其他 JAX 转换任意组合。

想了解更多关于 JAX 中的自动向量化,请查看自动向量化。

这只是 JAX 能做的一小部分。我们非常期待看到你用它做些什么!

🔪 JAX - 锋利的部分 🔪

原文:jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

在 Colab 中打开 在 Kaggle 中打开

levskaya@ mattjj@

在意大利乡间漫步时,人们会毫不犹豫地告诉您,JAX 具有 “una anima di pura programmazione funzionale”

JAX 是一种用于表达和组合数值程序转换的语言。JAX 还能够为 CPU 或加速器(GPU/TPU)编译数值程序。对于许多数值和科学程序,JAX 表现出色,但前提是它们必须按照我们下面描述的某些约束条件编写。

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp 

🔪 纯函数

JAX 的转换和编译设计仅适用于函数式纯的 Python 函数:所有输入数据通过函数参数传递,所有结果通过函数结果输出。纯函数如果以相同的输入调用,将始终返回相同的结果。

下面是一些函数示例,这些函数不是函数式纯的,因此 JAX 的行为与 Python 解释器不同。请注意,这些行为并不由 JAX 系统保证;正确使用 JAX 的方法是仅在函数式纯 Python 函数上使用它。

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.]))) 
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.] 
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.]))) 
First call:  4.0
Second call:  5.0
Third call, different type:  [14.] 
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value 
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

即使一个 Python 函数在内部实际上使用了有状态的对象,只要它不读取或写入外部状态,它就可以是函数式纯的:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.)) 
50.0 

不建议在希望jit的任何 JAX 函数中使用迭代器或任何控制流原语。原因是迭代器是一个引入状态以检索下一个元素的 Python 对象。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,有一些尝试在 JAX 中使用迭代器的错误示例。其中大多数会返回错误,但有些会给出意外的结果。

import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error 
45
0 

🔪 原地更新

在 Numpy 中,您习惯于执行以下操作:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array) 
original array:
[[0\. 0\. 0.]
 [0\. 0\. 0.]
 [0\. 0\. 0.]]
updated array:
[[0\. 0\. 0.]
 [1\. 1\. 1.]
 [0\. 0\. 0.]] 

然而,如果我们尝试在 JAX 设备数组上就地更新,我们会收到错误!(☉_☉)

%xmode Minimal 
Exception reporting mode: Minimal 
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0 
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html 

允许变量在原地变异会使程序分析和转换变得困难。JAX 要求程序是纯函数。

相反,JAX 提供了对 JAX 数组上的 .at 属性进行函数式数组更新

️⚠️ 在 jit 的代码中和 lax.while_looplax.fori_loop 中,切片的大小不能是参数 的函数,而只能是参数 形状 的函数 — 切片的起始索引没有此类限制。有关此限制的更多信息,请参阅下面的 控制流 部分。

数组更新:x.at[idx].set(y)

例如,上述更新可以写成:

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array) 
updated array:
 [[0\. 0\. 0.]
 [1\. 1\. 1.]
 [0\. 0\. 0.]] 

JAX 的数组更新函数与其 NumPy 版本不同,是在原地外执行的。也就是说,更新后的数组作为新数组返回,原始数组不会被更新修改。

print("original array unchanged:\n", jax_array) 
original array unchanged:
 [[0\. 0\. 0.]
 [0\. 0\. 0.]
 [0\. 0\. 0.]] 

然而,在jit编译的代码内部,如果x.at[idx].set(y)输入值 x 没有被重用,编译器会优化数组更新以进行原地操作。

使用其他操作的数组更新

索引数组更新不仅限于覆盖值。例如,我们可以进行索引加法如下:

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array) 
original array:
[[1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 1\. 1\. 1.]]
new array post-addition:
[[1\. 1\. 1\. 8\. 8\. 8.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 8\. 8\. 8.]
 [1\. 1\. 1\. 1\. 1\. 1.]
 [1\. 1\. 1\. 8\. 8\. 8.]] 

有关索引数组更新的更多详细信息,请参阅.at属性的文档

🔪 超出边界索引

在 NumPy 中,当您索引数组超出其边界时,通常会抛出错误,例如:

np.arange(10)[11] 
IndexError: index 11 is out of bounds for axis 0 with size 10 

然而,在加速器上运行的代码中引发错误可能会很困难或不可能。因此,JAX 必须为超出边界的索引选择一些非错误行为(类似于无效的浮点算术结果为NaN的情况)。当索引操作是数组索引更新时(例如index_add或类似的原语),将跳过超出边界的索引;当操作是数组索引检索时(例如 NumPy 索引或类似的原语),索引将夹紧到数组的边界,因为必须返回某些内容。例如,数组的最后一个值将从此索引操作中返回:

jnp.arange(10)[11] 
Array(9, dtype=int32) 

如果您希望对超出边界索引的行为有更精细的控制,可以使用ndarray.at的可选参数;例如:

jnp.arange(10.0).at[11].get() 
Array(9., dtype=float32) 
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) 
Array(nan, dtype=float32) 

注意由于这种索引检索行为,像jnp.nanargminjnp.nanargmax这样的函数在由 NaN 组成的切片中返回-1,而 NumPy 会抛出错误。

还请注意,由于上述两种行为不是互为反操作,反向模式自动微分(将索引更新转换为索引检索及其反之)将不会保留超出边界索引的语义。因此,将 JAX 中的超出边界索引视为未定义行为可能是个好主意。

🔪 非数组输入:NumPy vs. JAX

NumPy 通常可以接受 Python 列表或元组作为其 API 函数的输入:

np.sum([1, 2, 3]) 
np.int64(6) 

JAX 在这方面有所不同,通常会返回有用的错误:

jnp.sum([1, 2, 3]) 
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0. 

这是一个有意的设计选择,因为向追踪函数传递列表或元组可能导致性能下降,而这种性能下降可能很难检测到。

例如,请考虑允许列表输入的jnp.sum的以下宽松版本:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x) 
Array(45, dtype=int32) 

输出与预期相符,但这隐藏了底层的潜在性能问题。在 JAX 的追踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为单独的 JAX 变量,并分别处理和推送到设备。这可以在上面的permissive_sum函数的 jaxpr 中看到:

make_jaxpr(permissive_sum)(x) 
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
    v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
    x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
    y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
    z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
    ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
    bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
    bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
    be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) } 

列表的每个条目都作为单独的输入处理,导致追踪和编译开销随列表大小线性增长。为了避免这样的意外,JAX 避免将列表和元组隐式转换为数组。

如果您希望将元组或列表传递给 JAX 函数,可以首先显式地将其转换为数组:

jnp.sum(jnp.array(x)) 
Array(45, dtype=int32) 

🔪 随机数

如果所有因糟糕的rand()而存疑的科学论文都从图书馆书架上消失,每个书架上会有一个拳头大小的空白。 - Numerical Recipes

RNG 和状态

您习惯于从 numpy 和其他库中使用有状态的伪随机数生成器(PRNG),这些库在幕后巧妙地隐藏了许多细节,为您提供了伪随机性的丰富源泉:

print(np.random.random())
print(np.random.random())
print(np.random.random()) 
0.9818293835329528
0.06574727326903418
0.3930007618911092 

在底层,numpy 使用Mersenne Twister PRNG 来驱动其伪随机函数。该 PRNG 具有(2^{19937}-1)的周期,并且在任何时候可以由624 个 32 位无符号整数和一个表示已使用的“熵”量的位置来描述。

np.random.seed(0)
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538,  337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0) 

这个伪随机状态向量在每次需要随机数时都会在幕后自动更新,“消耗”Mersenne Twister 状态向量中的 2 个 uint32:

_ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#      4162027047, 3277342478], dtype=uint32), 2, 0, 0.0) 

魔法 PRNG 状态的问题在于很难推断它在不同线程、进程和设备中的使用和更新方式,并且在熵的生成和消耗细节对最终用户隐藏时,非常容易出错。

Mersenne Twister PRNG 也被认为存在一些问题,它具有较大的 2.5kB 状态大小,导致初始化问题很多。它在现代的 BigCrush 测试中失败,并且通常速度较慢。

JAX PRNG

相反,JAX 实现了一个显式的PRNG,其中熵的生成和消耗通过显式传递和迭代 PRNG 状态来处理。JAX 使用一种现代化的Threefry 基于计数器的 PRNG,它是可分裂的。也就是说,其设计允许我们将 PRNG 状态分叉成新的 PRNG,以用于并行随机生成。

随机状态由一个我们称之为密钥的特殊数组元素描述:

from jax import random
key = random.key(0)
key 
Array((), dtype=key<fry>) overlaying:
[0 0] 

JAX 的随机函数从 PRNG 状态生成伪随机数,但不会改变状态!

复用相同的状态会导致悲伤单调,剥夺最终用户生命力的混乱

print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key) 
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0] 

相反,我们分割PRNG 以在每次需要新的伪随机数时获得可用的子密钥

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom) 
old key Array((), dtype=key<fry>) overlaying:
[0 0]
    \---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389] 

我们传播密钥并在需要新的随机数时生成新的子密钥

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("    \---SPLIT --> new key   ", key)
print("             \--> new subkey", subkey, "--> normal", normal_pseudorandom) 
old key Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
    \---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055] 

我们可以同时生成多个子密钥

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,))) 
[-0.37533438]
[0.98645043]
[0.14553197] 

🔪 控制流

✔ python 控制流 + 自动微分 ✔

如果您只想将grad应用于您的 Python 函数,可以使用常规的 Python 控制流结构,没有问题,就像使用Autograd(或 Pytorch 或 TF Eager)一样。

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok! 
12.0
-4.0 

python 控制流 + JIT

使用jit进行控制流更为复杂,默认情况下具有更多约束。

这个可以工作:

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3)) 
24 

这样也可以:

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.]))) 
6.0 

但默认情况下,这样不行:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# This will fail!
f(2) 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_1227/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 

怎么回事!?

当我们jit编译一个函数时,通常希望编译一个适用于许多不同参数值的函数版本,以便我们可以缓存和重复使用编译代码。这样我们就不必在每次函数评估时重新编译。

例如,如果我们在数组jnp.array([1., 2., 3.], jnp.float32)上评估@jit函数,我们可能希望编译代码,以便我们可以重复使用它来在jnp.array([4., 5., 6.], jnp.float32)上评估函数,从而节省编译时间。

要查看适用于许多不同参数值的 Python 代码视图,JAX 会跟踪抽象值,这些抽象值表示可能输入集合的集合。有关不同的转换使用不同的抽象级别,详见多个不同的抽象级别

默认情况下,jit会在ShapedArray抽象级别上跟踪您的代码,其中每个抽象值表示具有固定形状和 dtype 的所有数组值的集合。例如,如果我们使用抽象值ShapedArray((3,), jnp.float32)进行跟踪,我们会得到可以重复使用于相应数组集合中的任何具体值的函数视图。这意味着我们可以节省编译时间。

但这里有一个权衡:如果我们在ShapedArray((), jnp.float32)上跟踪 Python 函数,它不专注于具体值,当我们遇到像if x < 3这样的行时,表达式x < 3会评估为表示集合{True, False}的抽象ShapedArray((), jnp.bool_)。当 Python 尝试将其强制转换为具体的TrueFalse时,我们会收到错误:我们不知道应该选择哪个分支,无法继续跟踪!权衡是,使用更高级别的抽象,我们获得 Python 代码的更一般视图(因此节省重新编译的时间),但我们需要更多约束来完成跟踪。

好消息是,您可以自行控制这种权衡。通过启用jit对更精细的抽象值进行跟踪,您可以放宽跟踪约束。例如,使用jitstatic_argnums参数,我们可以指定在某些参数的具体值上进行跟踪。下面是这个例子函数:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.)) 
12.0 

下面是另一个例子,这次涉及循环:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2) 
Array(5., dtype=float32) 

实际上,循环被静态展开。JAX 也可以在更高的抽象级别进行追踪,比如 Unshaped,但目前对于任何变换来说这都不是默认的。

️⚠️ 具有参数-值相关形状的函数

这些控制流问题也以更微妙的方式出现:我们希望 jit 的数值函数不能根据参数 来特化内部数组的形状(在参数 形状 上特化是可以的)。举个简单的例子,让我们创建一个函数,其输出恰好依赖于输入变量 length

def example_fun(length, val):
  return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4)) 
[4\. 4\. 4\. 4\. 4.] 
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4) 
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_1227/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length. 
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4)) 
[4\. 4\. 4\. 4\. 4\. 4\. 4\. 4\. 4\. 4.]
[4\. 4\. 4\. 4\. 4.] 

如果在我们的示例中 length 很少更改,那么 static_argnums 就会很方便,但如果它经常更改,那将是灾难性的!

最后,如果您的函数具有全局副作用,JAX 的追踪器可能会导致一些奇怪的事情发生。一个常见的坑是尝试在 jit 函数中打印数组:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2) 
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 
Array(4, dtype=int32, weak_type=True) 

结构化控制流原语

JAX 中有更多控制流选项。假设您想避免重新编译但仍想使用可追踪的控制流,并避免展开大循环。那么您可以使用这四个结构化的控制流原语:

  • lax.cond 可微分

  • lax.while_loop 前向模式可微分

  • lax.fori_loop 前向模式可微分;如果端点是静态的,则前向和反向模式均可微分

  • lax.scan 可微分

cond

python 等效:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand) 
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32) 
Array([-1.], dtype=float32) 

jax.lax 还提供了另外两个函数,允许根据动态谓词进行分支:

  • lax.select 类似于 lax.cond 的批处理版本,选择项表达为预先计算的数组而不是函数。

  • lax.switch 类似于 lax.cond,但允许在任意数量的可调用选项之间进行切换。

另外,jax.numpy 提供了几个 numpy 风格的接口:

  • jnp.where 的三个参数是 lax.select 的 numpy 风格封装。

  • jnp.piecewiselax.switch 的 numpy 风格封装,但是根据一系列布尔条件而不是单个标量索引进行切换。

  • jnp.select 的 API 类似于 jnp.piecewise,但选择项是作为预先计算的数组而不是函数给出的。它是基于多次调用 lax.select 实现的。

while_loop

python 等效:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val 
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32) 
Array(10, dtype=int32, weak_type=True) 

fori_loop

python 等效:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val 
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32) 
Array(45, dtype=int32, weak_type=True) 

总结

[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \ \hline \ \textrm{if} & ❌ & ✔ \ \textrm{for} & ✔* & ✔\ \textrm{while} & ✔* & ✔\ \textrm{lax.cond} & ✔ & ✔\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\ \textrm{lax.scan} & ✔ & ✔\ \hline \end{array} \end{split}]

(\ast) = 参数--独立循环条件 - 展开循环

🔪 动态形状

在像jax.jitjax.vmapjax.grad等变换中使用的 JAX 代码要求所有输出数组和中间数组具有静态形状:即形状不能依赖于其他数组中的值。

例如,如果您正在实现自己的版本jnp.nansum,您可能会从以下内容开始:

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum() 

在 JIT 和其他转换之外,这可以正常工作:

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x)) 
10.0 

如果尝试将jax.jit或另一个转换应用于此函数,则会报错:

jax.jit(nansum)(x) 
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError 

问题在于x_without_nans的大小取决于x中的值,这另一种方式说它的大小是动态的。通常在 JAX 中,可以通过其他方式绕过对动态大小数组的需求。例如,在这里可以使用jnp.where的三参数形式,将 NaN 值替换为零,从而计算出相同的结果,同时避免动态形状:

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x)) 
10.0 

在其他情况下,类似的技巧可以发挥作用,其中动态形状数组出现。

🔪 NaNs

调试 NaNs

如果要追踪你的函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:

  • 设置JAX_DEBUG_NANS=True环境变量;

  • 在你的主文件顶部添加jax.config.update("jax_debug_nans", True)

  • 在你的主文件中添加jax.config.parse_flags_with_absl(),然后使用命令行标志设置选项,如--jax_debug_nans=True

这将导致 NaN 产生时立即终止计算。打开此选项会在由 XLA 产生的每个浮点类型值上添加 NaN 检查。这意味着对于不在@jit下的每个基元操作,值将被拉回主机并作为 ndarrays 进行检查。对于在@jit下的代码,将检查每个@jit函数的输出,如果存在 NaN,则将以逐个操作的去优化模式重新运行函数,有效地一次移除一个@jit级别。

可能会出现棘手的情况,比如只在@jit下出现的 NaN,但在去优化模式下却不会产生。在这种情况下,你会看到警告消息打印出来,但你的代码将继续执行。

如果在梯度评估的反向传递中产生 NaNs,当在堆栈跟踪中引发异常时,您将位于 backward_pass 函数中,这本质上是一个简单的 jaxpr 解释器,以反向遍历原始操作序列。在下面的示例中,我们使用命令行env JAX_DEBUG_NANS=True ipython启动了一个 ipython repl,然后运行了以下命令:

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return Array(device_buffer, *result_shape)

FloatingPointError: invalid value 

捕获到生成的 NaN。通过运行%debug,我们可以获得后期调试器。正如下面的示例所示,这也适用于在@jit下的函数。

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ... 

当此代码在 @jit 函数的输出中看到 NaN 时,它调用去优化的代码,因此我们仍然可以获得清晰的堆栈跟踪。我们可以使用 %debug 运行事后调试器来检查所有值,以找出错误。

⚠️ 如果您不是在调试,就不应该开启 NaN 检查器,因为它可能会导致大量设备主机往返和性能回归!

⚠️ NaN 检查器在 pmap 中不起作用。要调试 pmap 代码中的 NaN,可以尝试用 vmap 替换 pmap

🔪 双精度(64 位)

目前,默认情况下,JAX 强制使用单精度数字,以减少 Numpy API 将操作数过度提升为 double 的倾向。这是许多机器学习应用程序的期望行为,但可能会让您感到意外!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype 
/tmp/ipykernel_1227/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) 
dtype('float32') 

要使用双精度数,您需要在启动时设置 jax_enable_x64 配置变量**。

有几种方法可以做到这一点:

  1. 您可以通过设置环境变量 JAX_ENABLE_X64=True 来启用 64 位模式。

  2. 您可以在启动时手动设置 jax_enable_x64 配置标志:

    # again, this only works on startup!
    import jax
    jax.config.update("jax_enable_x64", True) 
    
  3. 您可以使用 absl.app.run(main) 解析命令行标志

    import jax
    jax.config.config_with_absl() 
    
  4. 如果您希望 JAX 为您运行 absl 解析,即您不想执行 absl.app.run(main),您可以改用

    import jax
    if __name__ == '__main__':
      # calls jax.config.config_with_absl() *and* runs absl parsing
      jax.config.parse_flags_with_absl() 
    

请注意,#2-#4 适用于任何 JAX 的配置选项。

然后,我们可以确认已启用 x64 模式:

import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64') 
/tmp/ipykernel_1227/2819792939.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) 
dtype('float32') 

注意事项

⚠️ XLA 不支持所有后端的 64 位卷积!

🔪 NumPy 中的各种分歧

虽然 jax.numpy 尽力复制 numpy API 的行为,但确实存在一些边界情况,其行为有所不同。许多这样的情况在前面的部分中有详细讨论;这里我们列出了几个已知的其他 API 分歧处。

  • 对于二进制操作,JAX 的类型提升规则与 NumPy 略有不同。有关更多详细信息,请参阅类型提升语义

  • 在执行不安全类型转换(即目标 dtype 不能表示输入值的转换)时,JAX 的行为可能依赖于后端,并且通常可能与 NumPy 的行为不同。NumPy 允许通过 casting 参数(参见np.ndarray.astype)控制这些情况下的结果;JAX 不提供任何此类配置,而是直接继承XLA:ConvertElementType的行为。

    这是一个示例,显示了在 NumPy 和 JAX 之间存在不同结果的不安全转换:

    >>> np.arange(254.0, 258.0).astype('uint8')
    array([254, 255,   0,   1], dtype=uint8)
    
    >>> jnp.arange(254.0, 258.0).astype('uint8')
    Array([254, 255, 255, 255], dtype=uint8) 
    

    这种不匹配通常在将浮点值转换为整数类型或反之时出现极端情况。

结束。

如果这里没有涉及到您曾经因之而哭泣和咬牙切齿的问题,请告知我们,我们将扩展这些介绍性建议

JAX 常见问题解答(FAQ)

原文:jax.readthedocs.io/en/latest/faq.html

我们在这里收集了一些经常被问到的问题的答案。欢迎贡献!

jit改变了我的函数行为

如果你有一个在使用jax.jit()后改变行为的 Python 函数,也许你的函数使用了全局状态或具有副作用。在下面的代码中,impure_func使用了全局变量y并由于print而具有副作用:

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y)) 

没有jit时的输出是:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4 

并且使用jit时:

Inside: 0
Result: 0
Result: 1
Result: 2 

对于jax.jit(),函数在 Python 解释器中执行一次,此时发生Inside打印,并观察到y的第一个值。然后,函数被编译并缓存,以不同的x值多次执行,但y的第一个值相同。

更多阅读:

jit改变了输出的精确数值

有时候,用户会对用jit()包装一个函数后,函数的输出发生变化感到惊讶。例如:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365 
>>> print(jit(f)(x))
0.5723649 

这种输出的细微差异来自于 XLA 编译器中的优化:在编译过程中,XLA 有时会重新排列或省略某些操作,以使整体计算更加高效。

在这种情况下,XLA 利用对数的性质将log(sqrt(x))替换为0.5 * log(x),这是一个数学上相同的表达式,可以比原始表达式更有效地计算。输出的差异来自于浮点数运算只是对真实数学的近似,因此计算相同表达式的不同方式可能会有细微的差异。

其他时候,XLA 的优化可能导致更加显著的差异。考虑以下例子:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf 
>>> print(jit(f)(x))
100.0 

在非 JIT 编译的逐操作模式下,结果为inf,因为jnp.exp(x)溢出并返回inf。然而,在 JIT 模式下,XLA 认识到logexp的反函数,并从编译函数中移除这些操作,简单地返回输入。在这种情况下,JIT 编译产生了对真实结果更准确的浮点数近似。

遗憾的是,XLA 的代数简化的完整列表文档不是很好,但如果你熟悉 C++ 并且对 XLA 编译器进行的优化类型感兴趣,你可以在源代码中查看它们:algebraic_simplifier.cc。## jit修饰函数编译速度非常慢

如果你的jit修饰函数在第一次调用时需要数十秒(甚至更长时间!)来运行,但在后续调用时执行速度很快,那么 JAX 正在花费很长时间来追踪或编译你的代码。

这通常表明调用你的函数生成了大量 JAX 内部表示的代码,通常是因为它大量使用了 Python 控制流,比如for循环。对于少量循环迭代,Python 是可以接受的,但如果你需要许多循环迭代,你应该重写你的代码,利用 JAX 的结构化控制流原语(如lax.scan())或避免用jit包装循环(你仍然可以在循环内部使用jit装饰的函数)。

如果你不确定问题出在哪里,你可以尝试在你的函数上运行jax.make_jaxpr()。如果输出很长,可能会导致编译速度慢。

有时候不明显如何重写你的代码以避免 Python 循环,因为你的代码使用了多个形状不同的数组。在这种情况下推荐的解决方案是利用像jax.numpy.where()这样的函数,在具有固定形状的填充数组上进行计算。

如果你的函数由于其他原因编译速度很慢,请在 GitHub 上提一个问题。## 如何在方法中使用 jit

大多数jax.jit()的示例涉及装饰独立的 Python 函数,但在类内部装饰方法会引入一些复杂性。例如,请考虑以下简单的类,我们在方法上使用了标准的jit()注解:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y 

然而,当你尝试调用此方法时,这种方法将导致错误:

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type. 

问题在于函数的第一个参数是self,其类型为CustomClass,而 JAX 不知道如何处理这种类型。在这种情况下,我们可能会使用三种基本策略,并在下面讨论它们。

策略 1: JIT 编译的辅助函数

最直接的方法是在类外部创建一个辅助函数,可以像平常一样进行 JIT 装饰。例如:

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y 

结果将按预期工作:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6 

这种方法的好处是简单、明确,避免了教 JAX 如何处理CustomClass类型对象的需要。但是,你可能希望将所有方法逻辑保留在同一个地方。

策略 2: 将self标记为静态

另一种常见模式是使用static_argnumsself参数标记为静态。但是必须小心,以避免意外的结果。你可能会简单地这样做:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y 

如果你调用该方法,它将不再引发错误:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6 

然而,有一个问题:如果在第一次方法调用后修改对象,则后续方法调用可能会返回不正确的结果:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6 

为什么会这样?当你将对象标记为静态时,它将有效地被用作 JIT 内部编译缓存中的字典键,这意味着其哈希值(即 hash(obj) )、相等性(即 obj1 == obj2 )和对象身份(即 obj1 is obj2 )的行为应保持一致。自定义对象的默认 __hash__ 是其对象 ID,因此 JAX 无法知道突变对象应触发重新编译。

你可以通过为对象定义适当的 __hash____eq__ 方法来部分解决这个问题;例如:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul)) 

(参见object.__hash__() 的文档,进一步讨论在覆盖 __hash__ 时的要求)。

只要你不修改对象,这种方法与 JIT 和其他转换一起工作正常。将对象用作哈希键的突变会导致几个微妙的问题,这就是为什么例如可变 Python 容器(如dictlist)不定义 __hash__,而它们的不可变对应物(如tuple)会。

如果你的类依赖于原地突变(比如在其方法中设置 self.attr = ...),那么你的对象并非真正“静态”,将其标记为静态可能会导致问题。幸运的是,对于这种情况还有另一种选择。

策略 3:将 CustomClass 设为 PyTree

JIT 编译类方法的最灵活方法是将类型注册为自定义的 PyTree 对象;请参阅扩展 pytrees。这样可以明确指定类的哪些组件应视为静态,哪些应视为动态。以下是具体操作:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten) 

这当然更加复杂,但解决了上述简单方法所带来的所有问题:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6 

只要你的 tree_flattentree_unflatten 函数能正确处理类中所有相关属性,你应该能直接将这种类型的对象用作 JIT 编译函数的参数,而不需要任何特殊的注释。 ## 控制数据和计算在设备上的放置

让我们先来看看 JAX 中数据和计算放置的原则。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1)数据所在的设备;2)数据是否已提交到设备(有时称为数据对设备的粘性)。

默认情况下,JAX 数组被放置在默认设备上未提交状态 (jax.devices()[0]),这通常是第一个 GPU 或 TPU。如果没有 GPU 或 TPU 存在,jax.devices()[0] 是 CPU。可以通过 jax.default_device() 上下文管理器临时覆盖默认设备,或者通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 来设置整个进程的默认设备为 "cpu"、"gpu" 或 "tpu"(JAX_PLATFORMS 也可以是一个平台列表,指定优先顺序中可用的平台)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)} 

对涉及未提交数据的计算将在默认设备上执行,并且结果也会在默认设备上保持未提交状态。

数据也可以使用带有 device 参数的 jax.device_put() 明确地放置到设备上,在这种情况下,数据将会 提交 到设备上:

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)} 

包含一些已提交输入的计算将在已提交的设备上执行,并且结果将在同一设备上提交。在已提交到多个设备上的参数上调用操作将会引发错误。

也可以在没有 device 参数的情况下使用 jax.device_put()。如果数据已经在设备上(无论是已提交还是未提交状态),则保持不变。如果数据不在任何设备上,即它是常规的 Python 或 NumPy 值,则将其放置在默认设备上未提交状态。

经过 JIT 编译的函数行为与任何其他基本操作相同——它们会跟随数据,并且如果在提交到多个设备上的数据上调用时将会报错。

(在 2021 年 3 月之前的 PR #6002 中,创建数组常量时存在一些懒惰,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似的操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动。但为了简化实现,这种优化被移除了。)

(截至 2020 年 4 月,jax.jit() 函数有一个影响设备放置的 device 参数。该参数是实验性的,可能会被移除或更改,并且不建议使用。)

对于一个详细的例子,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

你刚刚将一个复杂的函数从 NumPy/SciPy 移植到 JAX。那真的加快了速度吗?

当使用 JAX 测量代码速度时,请记住与 NumPy 的这些重要差异:

  1. JAX 代码是即时编译(JIT)的。大多数使用 JAX 编写的代码可以以支持 JIT 编译的方式编写,这可以使其运行 更快(参见 To JIT or not to JIT)。为了从 JAX 中获得最大的性能,应在最外层的函数调用上应用 jax.jit()

    请记住,第一次运行 JAX 代码时,它会更慢,因为它正在被编译。即使在您自己的代码中不使用 jit,因为 JAX 的内置函数也是 JIT 编译的,这也是真实的。

  2. JAX 具有异步分派。 这意味着您需要调用 .block_until_ready() 来确保计算实际发生了(参见异步分派)。

  3. JAX 默认只使用 32 位数据类型。 您可能希望在 NumPy 中明确使用 32 位数据类型,或者在 JAX 中启用 64 位数据类型(参见Double (64 bit) precision)以进行公平比较。

  4. 在 CPU 和加速器之间传输数据需要时间。 如果您只想测量评估函数所需的时间,您可能希望先将数据传输到要运行的设备上(参见控制数据和计算放置在设备上)。

下面是一个将所有这些技巧放在一起进行微基准测试以比较 JAX 和 NumPy 的示例,利用 IPython 的便捷的 %time 和 %timeit 魔法命令

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime 

当在 Colab 上使用 GPU 运行时,我们看到:

  • NumPy 在 CPU 上每次评估需要 16.2 毫秒。

  • JAX 将 NumPy 数组复制到 GPU 上花费了 1.26 毫秒。

  • JAX 编译该函数需要 193 毫秒。

  • JAX 在 GPU 上每次评估需要 485 微秒。

在这种情况下,我们看到一旦数据传输完毕并且函数被编译,JAX 在 GPU 上重复评估时大约快了 30 倍。

这个比较公平吗?也许是。最终重要的性能是运行完整应用程序时的性能,这些应用程序不可避免地包含了一些数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 操作符执行矩阵乘法),以摊销 JAX/加速器相对于 NumPy/CPU 增加的开销。例如,如果我们将这个例子切换到使用 10x10 的输入,JAX/GPU 的运行速度比 NumPy/CPU 慢 10 倍(100 µs vs 10 µs)。

JAX 是否比 NumPy 更快?

用户经常试图通过这样的基准测试来回答一个问题,即 JAX 是否比 NumPy 更快;由于这两个软件包的差异,这并不是一个简单的问题。

广义上讲:

  • NumPy 操作是急切地、同步地执行,只在 CPU 上执行。

  • JAX 操作可以被急切地执行或者在编译之后执行(如果在 jit() 内部);它们被异步地分派(参见异步分派);并且它们可以在 CPU、GPU 或 TPU 上执行,每种设备都有非常不同且不断演变的性能特征。

这些架构差异使得直接比较 NumPy 和 JAX 的基准测试变得困难。

另外,这些差异还导致了软件包在工程上的不同关注点:例如,NumPy 大力减少了单个数组操作的每次调用分派开销,因为在 NumPy 的计算模型中,这种开销是无法避免的。另一方面,JAX 有几种方法可以避免分派开销(例如,JIT 编译、异步分派、批处理转换等),因此减少每次调用的开销并不是一个首要任务。

综上所述,在总结时:如果您在 CPU 上进行单个数组操作的微基准测试,通常可以预期 NumPy 由于其较低的每次操作分派开销而胜过 JAX。如果您在 GPU 或 TPU 上运行代码,或者在 CPU 上进行更复杂的 JIT 编译操作序列的基准测试,通常可以预期 JAX 胜过 NumPy。##不同类型的 JAX 值

在转换函数过程中,JAX 会用特殊的追踪器值替换一些函数参数。

如果您使用print语句,您可以看到这一点:

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.) 

上述代码确实返回了正确的值1.,但它还打印出了Traced<ShapedArray(float32[])>作为x的值。通常情况下,JAX 在内部以透明的方式处理这些追踪器值,例如,在用于实现jax.numpy函数的数值 JAX 原语中。这就是为什么在上面的例子中jnp.cos能够正常工作的原因。

更确切地说,追踪器值用于 JAX 变换函数的参数,除了由jax.jit()的特殊参数(如static_argnums)或jax.pmap()static_broadcasted_argnums标识的参数。通常,涉及至少一个追踪器值的计算将产生一个追踪器值。除了追踪器值之外,还有常规Python 值:在 JAX 变换之外计算的值,或者来自上述特定 JAX 变换的静态参数,或者仅仅是来自其他常规 Python 值的计算。在缺少 JAX 变换的情况下,这些值在任何地方都可以使用。

一个追踪器值携带一个抽象值,例如,ShapedArray包含有关数组形状和 dtype 的信息。我们将这些追踪器称为抽象追踪器。一些追踪器,例如,为自动微分变换的参数引入的那些,携带包含实际数组数据的ConcreteArray抽象值,并且用于解析条件。我们将这些追踪器称为具体追踪器。从这些具体追踪器计算出的追踪器值,也许与常规值结合,会产生具体追踪器。具体值是指常规值或具体追踪器。

大多数情况下,从追踪值计算得到的值本身也是追踪值。只有极少数例外情况,当一个计算可以完全使用追踪器携带的抽象值时,其结果可以是常规值。例如,使用 ShapedArray 抽象值获取追踪器的形状。另一个例子是显式地将具体的追踪器值转换为常规类型,例如 int(x)x.astype(float)。另一种情况是对 bool(x) 的处理,在具体性允许的情况下会产生 Python 布尔值。由于这种情况在控制流中经常出现,所以这种情况尤为显著。

下面是这些转换如何引入抽象或具体追踪器的说明:

  • jax.jit():除了由 static_argnums 指定的位置参数之外,为所有位置参数引入抽象追踪器,这些参数保持为常规值。

  • jax.pmap():除了由 static_broadcasted_argnums 指定的位置参数之外,为所有位置参数引入抽象追踪器

  • jax.vmap()jax.make_jaxpr()xla_computation():为所有位置参数引入抽象追踪器

  • jax.jvp()jax.grad() 为所有位置参数引入具体追踪器。唯一的例外是当这些转换在外部转换内部进行时,实际参数本身就是抽象追踪器时,由自动微分转换引入的追踪器也是抽象追踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时引入抽象追踪器,无论是否存在 JAX 转换。

当您的代码仅能操作常规的 Python 值时,例如基于数据的条件控制流的代码时,这些都是相关的:

def divide(x, y):
  return x / y if y >= 1. else 0. 

如果我们想要应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为常规值。这是由于布尔表达式 y >= 1.,它需要具体的值(常规或追踪器)。如果我们显式地编写 bool(y >= 1.)int(y)float(y),也会发生同样的情况。

有趣的是,jax.grad(divide)(3., 2.) 是有效的,因为 jax.grad() 使用具体追踪器,并使用 y 的具体值解析条件。 ## 缓冲捐赠

当 JAX 执行计算时,它使用设备上的缓冲区来处理所有输入和输出。如果您知道某个输入在计算后不再需要,并且它与某个输出的形状和元素类型匹配,您可以指定要捐赠相应输入的缓冲区来保存输出。这将减少执行所需的内存,减少捐赠缓冲区的大小。

如果您有类似以下模式的情况,可以使用缓冲捐赠:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state) 

您可以将此视为一种在不可变 JAX 数组上进行内存高效的函数更新的方法。在计算的 XLA 边界内,XLA 可以为您进行此优化,但在 jit/pmap 边界处,您需要向 XLA 保证在调用捐赠函数后不会再使用捐赠的输入缓冲区。

您可以通过在函数jax.jit()jax.pjit()jax.pmap()中使用donate_argnums参数来实现这一点。此参数是位置参数列表(从 0 开始)的索引序列:

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y) 

注意,如果使用关键字参数调用函数,则此方法目前不起作用!以下代码不会捐赠任何缓冲区:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state) 

如果一个参数的缓冲区被捐赠,且其为 pytree,则其所有组件的缓冲区都会被捐赠:

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs) 

不允许捐赠随后在计算中使用的缓冲区,因此在 y 的缓冲区捐赠后,JAX 会报错因为该缓冲区已失效:

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer 

如果捐赠的缓冲区未被使用,则会收到警告,例如因为捐赠的缓冲区多于输出所需:

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0} 

如果没有输出的形状与捐赠匹配,则捐赠可能也不会被使用:

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0} 

使用where时,梯度包含 NaN

如果定义一个使用where来避免未定义值的函数,如果不小心可能会得到一个反向微分的NaN

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN 

简而言之,在grad计算期间,对于未定义的jnp.log(x)的伴随是NaN,并且会累积到jnp.where的伴随中。正确的编写这类函数的方法是确保在部分定义的函数内部有一个jnp.where,以确保伴随始终是有限的:

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok 

除原始jnp.where外可能还需要内部的jnp.where,例如:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y) 

进一步阅读:

基于排序顺序的函数为何梯度为零?

如果定义一个处理输入的函数,并使用依赖于输入相对顺序的操作(例如maxgreaterargsort等),那么可能会惊讶地发现梯度在所有位置都为零。以下是一个例子,我们定义 f(x)为一个阶跃函数,在 x 为负时返回 0,在 x 为正时返回 1:

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0\. 0\. 0\. 1\. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0\. 0\. 0\. 0\. 0.] 

虽然输出对输入有响应,但梯度在所有位置为零可能会令人困惑:毕竟,输出确实随输入而变化,那么梯度怎么可能是零呢?然而,在这种情况下,零确实是正确的结果。

这是为什么?请记住,微分测量的是给定 xf 的变化。对于 x=1.0f 返回 1.0。如果我们微扰 x 使其稍大或稍小,这并不会改变输出,因此根据定义,grad(f)(1.0) 应该为零。对于所有大于零的 f 值,此逻辑同样成立:微扰输入不会改变输出,因此梯度为零。同样,对于所有小于零的 x 值,输出为零。微扰 x 不会改变这个输出,因此梯度为零。这让我们面对 x=0 的棘手情况。当然,如果你向上微扰 x,它会改变输出,但这是有问题的:x 的微小变化会产生函数值的有限变化,这意味着梯度是未定义的。幸运的是,在这种情况下我们还有另一种方法来测量梯度:我们向下微扰函数,此时输出不变,因此梯度为零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但其中一个被定义,另一个未定义,我们使用被定义的那个。根据梯度的这一定义,从数学和数值上来说,此函数的梯度在任何地方都是零。

问题在于我们的函数在 x = 0 处有不连续点。我们的 f 本质上是一个 Heaviside Step Function,我们可以使用 Sigmoid Function 作为平滑替代。当 x 远离零时,Sigmoid 函数近似等于 Heaviside 函数,但在 x = 0 处用一个平滑的、可微的曲线替换不连续性。通过使用 jax.nn.sigmoid(),我们得到一个具有良定义梯度的类似计算:

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0\.   0.27 0.5  0.73 1\.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0\.   0.2  0.25 0.2  0\.  ] 

jax.nn 子模块还有其他常见基于排名的函数的平滑版本,例如 jax.nn.softmax() 可以替换 jax.numpy.argmax() 的使用,jax.nn.soft_sign() 可以替换 jax.numpy.sign() 的使用,jax.nn.softplus()jax.nn.squareplus() 可以替换 jax.nn.relu() 的使用,等等。

我如何将 JAX 追踪器转换为 NumPy 数组?

在运行时检查转换后的 JAX 函数时,您会发现数组值被 Tracer 对象替换:

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5)) 

这将打印如下内容:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> 

一个常见的问题是如何将这样的追踪器转换回正常的 NumPy 数组。简而言之,无法将追踪器转换为 NumPy 数组,因为追踪器是具有给定形状和数据类型的每一个可能值的抽象表示,而 NumPy 数组是该抽象类的具体成员。有关在 JAX 转换环境中追踪器工作的更多讨论,请参阅 JIT mechanics

将跟踪器转换回数组的问题通常出现在与运行时访问计算中的中间值相关的另一个目标的背景下。例如:

  • 如果您希望出于调试目的在运行时打印跟踪值,您可以考虑使用jax.debug.print()

  • 如果您希望在转换后的 JAX 函数中调用非 JAX 代码,您可以考虑使用jax.pure_callback(),其示例可在纯回调示例中找到。

  • 如果您希望在运行时输入或输出数组缓冲区(例如,从文件加载数据或将数组内容记录到磁盘),您可以考虑使用jax.experimental.io_callback(),其示例可在IO 回调示例中找到。

关于运行时回调的更多信息和它们的使用示例,请参阅JAX 中的外部回调

为什么会有些 CUDA 库加载/初始化失败?

在解析动态库时,JAX 使用通常的动态链接器搜索模式。JAX 将RPATH设置为指向通过 pip 安装的 NVIDIA CUDA 软件包的 JAX 相对位置,如果安装了这些软件包,则优先使用它们。如果ld.so在其通常的搜索路径中找不到您的 CUDA 运行时库,则必须在LD_LIBRARY_PATH中显式包含这些库的路径。确保您的 CUDA 文件可被发现的最简单方法是简单地安装标准的jax[cuda_12]安装选项中包含的nvidia-*-cu12 pip 软件包。

偶尔,即使您确保您的运行时库可被发现,仍可能存在加载或初始化的问题。这类问题的常见原因是运行时 CUDA 库初始化时内存不足。这有时是因为 JAX 将预分配当前可用设备内存的太大一部分以提高执行速度,偶尔会导致没有足够的内存留给运行时 CUDA 库初始化。

在运行多个 JAX 实例、与执行自己的预分配的 TensorFlow 并行运行 JAX,或者在 GPU 被其他进程大量使用的系统上运行 JAX 时,特别容易发生这种情况。如果有疑问,请尝试使用减少预分配来重新运行程序,可以通过减少XLA_PYTHON_CLIENT_MEM_FRACTION(默认为.75)或设置XLA_PYTHON_CLIENT_PREALLOCATE=false来实现。有关更多详细信息,请参阅JAX GPU 内存分配页面。

posted @ 2024-06-21 14:07  绝不原创的飞龙  阅读(52)  评论(0编辑  收藏  举报