JAX-中文文档-三-
JAX 中文文档(三)
有状态计算
JAX 的转换(如jit()
、vmap()
、grad()
)要求它们包装的函数是纯粹的:即,函数的输出仅依赖于输入,并且没有副作用,比如更新全局状态。您可以在JAX sharp bits: Pure functions中找到关于这一点的讨论。
在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可以以多种形式存在。例如:
-
模型参数,
-
优化器状态,以及
-
像BatchNorm这样的有状态层。
本节提供了如何在 JAX 程序中正确处理状态的一些建议。
一个简单的例子:计数器
让我们首先看一个简单的有状态程序:一个计数器。
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
1
2
3
计数器的n
属性在连续调用count
时维护计数器的状态。调用count
的副作用是修改它。
假设我们想要快速计数,所以我们即时编译count
方法。(在这个例子中,这实际上不会以任何方式加快速度,由于很多原因,但把它看作是模型参数更新的玩具模型,jit()
确实产生了巨大的影响)。
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1
1
1
哦不!我们的计数器不能工作了。这是因为
self.n += 1
在count
中涉及副作用:它直接修改了输入的计数器,因此此函数不受jit
支持。这样的副作用仅在首次跟踪函数时执行一次,后续调用将不会重复该副作用。那么,我们该如何修复它呢?
解决方案:显式状态
问题的一部分在于我们的计数器返回值不依赖于参数,这意味着编译输出中包含了一个常数。但它不应该是一个常数 - 它应该依赖于状态。那么,为什么我们不将状态作为一个参数呢?
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
1
2
3
在这个Counter
的新版本中,我们将n
移动到count
的参数中,并添加了另一个返回值,表示新的、更新的状态。现在,为了使用这个计数器,我们需要显式地跟踪状态。但作为回报,我们现在可以安全地使用jax.jit
这个计数器:
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
1
2
3
一个一般的策略
我们可以将同样的过程应用到任何有状态方法中,将其转换为无状态方法。我们拿一个形式如下的类
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
并将其转换为以下形式的类
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
这是一个常见的函数式编程模式,本质上就是处理所有 JAX 程序中状态的方式。
注意,一旦我们按照这种方式重写它,类的必要性就不那么明显了。我们可以只保留stateless_method
,因为类不再执行任何工作。这是因为,像我们刚刚应用的策略一样,面向对象编程(OOP)是帮助程序员理解程序状态的一种方式。
在我们的情况下,CounterV2
类只是一个名称空间,将所有使用 CounterState
的函数集中在一个位置。读者可以思考:将其保留为类是否有意义?
顺便说一句,你已经在 JAX 伪随机性 API 中看到了这种策略的示例,即 jax.random
,在 :ref:pseudorandom-numbers
部分展示。与 Numpy 不同,后者使用隐式更新的有状态类管理随机状态,而 JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。
简单的工作示例:线性回归
现在让我们将这种策略应用到一个简单的机器学习模型上:通过梯度下降进行线性回归。
这里,我们只处理一种状态:模型参数。但通常情况下,你会看到许多种状态在 JAX 函数中交替出现,比如优化器状态、批归一化的层统计数据等。
需要仔细查看的函数是 update
。
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
#
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
#
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
注意,我们手动地将参数输入和输出到更新函数中。
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
/tmp/ipykernel_2992/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
new_params = jax.tree_map(
进一步探讨
上述描述的策略是任何使用 jit
、vmap
、grad
等转换的 JAX 程序必须处理状态的方式。
如果只涉及两个参数,手动处理参数似乎还可以接受,但如果是有数十层的神经网络呢?你可能已经开始担心两件事情:
-
我们是否应该手动初始化它们,基本上是在前向传播定义中已经编写过的内容?
-
我们是否应该手动处理所有这些事情?
处理这些细节可能有些棘手,但有一些库的示例可以为您解决这些问题。请参阅JAX 神经网络库获取一些示例。
进一步资源
用户指南
用户指南是对 JAX 内特定主题的深入探讨,随着您的 JAX 项目发展成为更大或部署代码库,这些主题变得更为相关。
调试和性能
-
如何在 JAX 中思考
-
对 JAX 程序进行性能分析
-
设备内存分析
-
JAX 中的运行时值调试
-
GPU 性能技巧
-
持久化编译缓存
开发
-
理解 Jaxprs
-
JAX 中的外部回调
-
类型提升语义
-
Pytrees
运行时间
-
提前降低和编译
-
导出和序列化
-
JAX 错误
-
转移保护
自定义操作
- Pallas:一种 JAX 内核语言
如何在 JAX 中思考
原文:
jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
JAX 提供了一个简单而强大的 API 用于编写加速数值代码,但在 JAX 中有效工作有时需要额外考虑。本文档旨在帮助建立对 JAX 如何运行的基础理解,以便您更有效地使用它。
JAX vs. NumPy
关键概念:
-
JAX 提供了一个方便的类似于 NumPy 的接口。
-
通过鸭子类型,JAX 数组通常可以直接替换 NumPy 数组。
-
不像 NumPy 数组,JAX 数组总是不可变的。
NumPy 提供了一个众所周知且功能强大的 API 用于处理数值数据。为方便起见,JAX 提供了 jax.numpy
,它紧密反映了 NumPy API,并为进入 JAX 提供了便捷的入口。几乎可以用 jax.numpy
完成 numpy
可以完成的任何事情:
import matplotlib.pyplot as plt
import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
代码块除了用 jnp
替换 np
外,其余完全相同。正如我们所见,JAX 数组通常可以直接替换 NumPy 数组,用于诸如绘图等任务。
这些数组本身是作为不同的 Python 类型实现的:
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl
Python 的 鸭子类型 允许在许多地方可互换使用 JAX 数组和 NumPy 数组。
然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,一旦创建,其内容无法更改。
这里有一个在 NumPy 中突变数组的例子:
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10 1 2 3 4 5 6 7 8 9]
在 JAX 中,等效操作会导致错误,因为 JAX 数组是不可变的:
%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
对于更新单个元素,JAX 提供了一个 索引更新语法,返回一个更新后的副本:
y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10 1 2 3 4 5 6 7 8 9]
NumPy、lax 和 XLA:JAX API 层次结构
关键概念:
-
jax.numpy
是一个提供熟悉接口的高级包装器。 -
jax.lax
是一个更严格且通常更强大的低级 API。 -
所有 JAX 操作都是基于 XLA – 加速线性代数编译器中的操作实现的。
如果您查看 jax.numpy
的源代码,您会看到所有操作最终都是以 jax.lax
中定义的函数形式表达的。您可以将 jax.lax
视为更严格但通常更强大的 API,用于处理多维数组。
例如,虽然jax.numpy
将隐式促进参数以允许不同数据类型之间的操作,但jax.lax
不会:
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))): op requires the same element type for all operands and results
The above exception was the direct cause of the following exception:
ValueError: Cannot lower jaxpr with verifier errors:
op requires the same element type for all operands and results
at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module.
如果直接使用jax.lax
,在这种情况下你将需要显式地进行类型提升:
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
除了这种严格性外,jax.lax
还提供了一些比 NumPy 支持的更一般操作更高效的 API。
例如,考虑一个 1D 卷积,在 NumPy 中可以这样表达:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
在幕后,这个 NumPy 操作被转换为由lax.conv_general_dilated
实现的更通用的卷积:
from jax import lax
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
这是一种批处理卷积操作,专为深度神经网络中经常使用的卷积类型设计,需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多细节,请参见Convolutions in JAX)。
从本质上讲,所有jax.lax
操作都是 XLA 中操作的 Python 包装器;例如,在这里,卷积实现由XLA:ConvWithGeneralPadding提供。每个 JAX 操作最终都是基于这些基本 XLA 操作表达的,这就是使得即时(JIT)编译成为可能的原因。
要 JIT 或不要 JIT
关键概念:
-
默认情况下,JAX 按顺序逐个执行操作。
-
使用即时(JIT)编译装饰器,可以优化操作序列并一次运行:
-
并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。
所有 JAX 操作都是基于 XLA 表达的事实,使得 JAX 能够使用 XLA 编译器非常高效地执行代码块。
例如,考虑此函数,它对二维矩阵的行进行标准化,表达为jax.numpy
操作:
import jax.numpy as jnp
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
可以使用jax.jit
变换创建函数的即时编译版本:
from jax import jit
norm_compiled = jit(norm)
此函数返回与原始函数相同的结果,达到标准浮点精度:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
但由于编译(其中包括操作的融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可以比非常数级别快得多(请注意使用block_until_ready()
以考虑 JAX 的异步调度):
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
319 μs ± 1.98 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
272 μs ± 849 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
话虽如此,jax.jit
确实存在一些限制:特别是,它要求所有数组具有静态形状。这意味着一些 JAX 操作与 JIT 编译不兼容。
例如,此操作可以在逐操作模式下执行:
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
但如果您尝试在 jit 模式下执行它,则会返回错误:
jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
这是因为该函数生成的数组形状在编译时未知:输出的大小取决于输入数组的值,因此与 JIT 不兼容。
JIT 机制:跟踪和静态变量
关键概念:
-
JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型输入的影响。
-
不希望被追踪的变量可以标记为静态
要有效使用 jax.jit
,理解其工作原理是很有用的。让我们在一个 JIT 编译的函数中放几个 print()
语句,然后调用该函数:
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
result = jnp.dot(x + 1, y + 1)
print(f" result = {result}")
return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)
注意,打印语句执行,但打印的不是我们传递给函数的数据,而是打印追踪器对象,这些对象代替它们。
这些追踪器对象是 jax.jit
用来提取函数指定的操作序列的基本替代物,编码数组的形状和dtype,但对值是不可知的。然后可以有效地将这个记录的计算序列应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。
当我们在匹配的输入上再次调用编译函数时,无需重新编译,也不打印任何内容,因为结果在编译的 XLA 中计算,而不是在 Python 中:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
提取的操作序列编码在 JAX 表达式中,简称为 jaxpr。您可以使用 jax.make_jaxpr
转换查看 jaxpr:
from jax import make_jaxpr
def f(x, y):
return jnp.dot(x + 1, y + 1)
make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
c:f32[3,4] = add a 1.0
d:f32[4] = add b 1.0
e:f32[3] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] c d
in (e,) }
注意这一后果:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于追踪的值。例如,这将失败:
@jit
def f(x, neg):
return -x if neg else x
f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2814/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
如果有不希望被追踪的变量,可以将它们标记为静态以供 JIT 编译使用:
from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
return -x if neg else x
f(1, True)
Array(-1, dtype=int32, weak_type=True)
请注意,使用不同的静态参数调用 JIT 编译函数会导致重新编译,所以函数仍然如预期般工作:
f(1, False)
Array(1, dtype=int32, weak_type=True)
理解哪些值和操作将是静态的,哪些将被追踪,是有效使用 jax.jit
的关键部分。
静态与追踪操作
关键概念:
-
就像值可以是静态的或者被追踪的一样,操作也可以是静态的或者被追踪的。
-
静态操作在 Python 中在编译时评估;跟踪操作在 XLA 中在运行时编译并评估。
-
使用
numpy
进行您希望静态的操作;使用jax.numpy
进行您希望被追踪的操作。
静态和追踪值的区别使得重要的是考虑如何保持静态值的静态。考虑这个函数:
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2814/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /tmp/ipykernel_2814/1983583872.py:6 (f)
这会因为找到追踪器而不是整数类型的具体值的 1D 序列而失败。让我们向函数中添加一些打印语句,以了解其原因:
@jit
def f(x):
print(f"x = {x}")
print(f"x.shape = {x.shape}")
print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
# comment this out to avoid the error:
# return x.reshape(jnp.array(x.shape).prod())
f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
注意尽管x
被追踪,x.shape
是一个静态值。然而,当我们在这个静态值上使用jnp.array
和jnp.prod
时,它变成了一个被追踪的值,在这种情况下,它不能用于像reshape()
这样需要静态输入的函数(回想:数组形状必须是静态的)。
一个有用的模式是使用numpy
进行应该是静态的操作(即在编译时完成),并使用jax.numpy
进行应该被追踪的操作(即在运行时编译和执行)。对于这个函数,可能会像这样:
from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
return x.reshape((np.prod(x.shape),))
f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)
因此,在 JAX 程序中的一个标准约定是import numpy as np
和import jax.numpy as jnp
,这样两个接口都可以用来更精细地控制操作是以静态方式(使用numpy
,一次在编译时)还是以追踪方式(使用jax.numpy
,在运行时优化)执行。
对 JAX 程序进行性能分析
使用 Perfetto 查看程序跟踪
我们可以使用 JAX 分析器生成可以使用Perfetto 可视化工具查看的 JAX 程序的跟踪。目前,此方法会阻塞程序,直到点击链接并加载 Perfetto UI 以打开跟踪为止。如果您希望获取性能分析信息而无需任何交互,请查看下面的 Tensorboard 分析器。
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
计算完成后,程序会提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。
加载链接后,程序执行将继续。链接在打开一次后将不再有效,但将重定向到一个保持有效的新 URL。然后,您可以在 Perfetto UI 中单击“共享”按钮,创建可与他人共享的跟踪的永久链接。
远程分析
在对远程运行的代码进行性能分析(例如在托管的虚拟机上)时,您需要在端口 9001 上建立 SSH 隧道以使链接工作。您可以使用以下命令执行此操作:
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
手动捕获
而不是使用jax.profiler.trace
以编程方式捕获跟踪,您可以通过在感兴趣的脚本中调用jax.profiler.start_server(<port>)
来启动分析服务器。如果您只需在脚本的某部分保持分析服务器活动,则可以通过调用jax.profiler.stop_server()
来关闭它。
脚本运行后并且分析服务器已启动后,我们可以通过运行以下命令手动捕获和跟踪:
$ python -m jax.collect_profile <port> <duration_in_ms>
默认情况下,生成的跟踪信息会被转储到临时目录中,但可以通过传递--log_dir=<自定义目录>
来覆盖此设置。另外,默认情况下,程序将提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。通过传递--no_perfetto_link
命令可以禁用此功能。或者,您也可以将 Tensorboard 指向log_dir
以分析跟踪(参见下面的“Tensorboard 分析”部分)。
TensorBoard 性能分析
TensorBoard 的分析器可用于分析 JAX 程序。Tensorboard 是获取和可视化程序性能跟踪和分析(包括 GPU 和 TPU 上的活动)的好方法。最终结果看起来类似于这样:
安装
TensorBoard 分析器仅与捆绑有 TensorFlow 的 TensorBoard 版本一起提供。
pip install tensorflow tensorboard-plugin-profile
如果您已安装了 TensorFlow,则只需安装tensorboard-plugin-profile
pip 包。请注意仅安装一个版本的 TensorFlow 或 TensorBoard,否则可能会遇到下面描述的“重复插件”错误。有关安装 TensorBoard 的更多信息,请参见www.tensorflow.org/guide/profiler
。
程序化捕获
您可以通过jax.profiler.start_trace()
和jax.profiler.stop_trace()
方法来配置您的代码以捕获性能分析器的追踪。调用start_trace()
时需要指定写入追踪文件的目录。这个目录应该与启动 TensorBoard 时使用的--logdir
目录相同。然后,您可以使用 TensorBoard 来查看这些追踪信息。
例如,要获取性能分析器的追踪:
import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
注意block_until_ready()
调用。我们使用这个函数来确保设备上的执行被追踪到。有关为什么需要这样做的详细信息,请参见异步调度部分。
您还可以使用jax.profiler.trace()
上下文管理器作为start_trace
和stop_trace
的替代方法:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
要查看追踪信息,请首先启动 TensorBoard(如果尚未启动):
$ tensorboard --logdir=/tmp/tensorboard
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)
在这个示例中,您应该能够在localhost:6006/
加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。
然后,要么在右上角的下拉菜单中选择“Profile”,要么直接访问localhost:6006/#profile
。可用的追踪信息会显示在左侧的“Runs”下拉菜单中。选择您感兴趣的运行,并在“Tools”下选择trace_viewer
。现在您应该能看到执行时间轴。您可以使用 WASD 键来导航追踪信息,点击或拖动以选择事件并查看底部的更多详细信息。有关使用追踪查看器的更多详细信息,请参阅这些 TensorFlow 文档。
您还可以使用memory_viewer
、op_profile
和graph_viewer
工具。
通过 TensorBoard 手动捕获
以下是从运行中的程序中手动触发 N 秒追踪的捕获说明。
-
启动 TensorBoard 服务器:
tensorboard --logdir /tmp/tensorboard/
在
localhost:6006/
处应该能够加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。 -
在您希望进行分析的 Python 程序或进程中,将以下内容添加到开头的某个位置:
import jax.profiler jax.profiler.start_server(9999)
这将启动 TensorBoard 连接到的性能分析器服务器。在继续下一步之前,必须先运行性能分析器服务器。完成后,可以调用
jax.profiler.stop_server()
来关闭它。如果你想要分析一个长时间运行的程序片段(例如长时间的训练循环),你可以将此代码放在程序开头并像往常一样启动程序。如果你想要分析一个短程序(例如微基准测试),一种选择是在 IPython shell 中启动分析器服务器,并在下一步开始捕获后用
%run
运行短程序。另一种选择是在程序开头启动分析器服务器,并使用time.sleep()
给你足够的时间启动捕获。 -
打开
localhost:6006/#profile
,并点击左上角的“CAPTURE PROFILE”按钮。将“localhost:9999”作为分析服务的 URL(这是你在上一步中启动的分析器服务器的地址)。输入你想要进行分析的毫秒数,然后点击“CAPTURE”。 -
如果你想要分析的代码尚未运行(例如在 Python shell 中启动了分析器服务器),请在进行捕获时运行它。
-
捕获完成后,TensorBoard 应会自动刷新。(并非所有 TensorBoard 分析功能都与 JAX 连接,所以初始时看起来可能没有捕获到任何内容。)在左侧的“工具”下,选择
trace_viewer
。现在你应该可以看到执行的时间轴。你可以使用 WASD 键来导航跟踪,点击或拖动选择事件以在底部查看更多详细信息。参见这些 TensorFlow 文档获取有关使用跟踪查看器的更多详细信息。
你也可以使用
memory_viewer
、op_profile
和graph_viewer
工具。
添加自定义跟踪事件
默认情况下,跟踪查看器中的事件大多是低级内部 JAX 函数。你可以使用 jax.profiler.TraceAnnotation
和 jax.profiler.annotate_function()
在你的代码中添加自定义事件和函数。
故障排除
GPU 分析
运行在 GPU 上的程序应该在跟踪查看器顶部附近生成 GPU 流的跟踪。如果只看到主机跟踪,请检查程序日志和/或输出,查看以下错误消息。
如果出现类似 Could not load dynamic library 'libcupti.so.10.1'
的错误
完整错误:
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
将libcupti.so
的路径添加到环境变量LD_LIBRARY_PATH
中。(尝试使用locate libcupti.so
来找到路径。)例如:
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
即使在做了以上步骤后仍然收到 Could not load dynamic library
错误消息,请检查 GPU 跟踪是否仍然显示在跟踪查看器中。有时即使一切正常,它也会出现此消息,因为它在多个位置查找 libcupti
库。
如果出现类似 failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
的错误
完整错误:
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED
运行以下命令(注意这将需要重新启动):
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now
查看更多关于此错误的信息,请参阅NVIDIA 的文档。
在远程机器上进行性能分析
如果要分析的 JAX 程序正在远程机器上运行,一种选择是在远程机器上执行上述所有说明(特别是在远程机器上启动 TensorBoard 服务器),然后使用 SSH 本地端口转发从本地访问 TensorBoard Web UI。使用以下 SSH 命令将默认的 TensorBoard 端口 6006 从本地转发到远程机器:
ssh -L 6006:localhost:6006 <remote server address>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006
``` #### 多个 TensorBoard 安装
**如果启动 TensorBoard 失败,并出现类似于`ValueError: Duplicate plugins for name projector`的错误**
这通常是因为安装了两个版本的 TensorBoard 和/或 TensorFlow(例如,`tensorflow`、`tf-nightly`、`tensorboard`和`tb-nightly` pip 包都包含 TensorBoard)。卸载一个 pip 包可能会导致`tensorboard`可执行文件被移除,难以替换,因此可能需要卸载所有内容并重新安装单个版本:
```py
pip uninstall tensorflow tf-nightly tensorboard tb-nightly
pip install tensorflow
Nsight
NVIDIA 的Nsight
工具可用于跟踪和分析 GPU 上的 JAX 代码。有关详情,请参阅Nsight
文档。
设备内存分析
原文:
jax.readthedocs.io/en/latest/device_memory_profiling.html
注意
2023 年 5 月更新:我们建议使用 Tensorboard 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer
标签以获取更详细和易于理解的设备内存使用情况。
JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可用于:
-
查明在特定时间点哪些数组和可执行文件位于 GPU 内存中,或者
-
追踪内存泄漏。
安装
JAX 设备内存分析器生成的输出可使用 pprof (google/pprof) 解释。首先按照其 安装说明 安装 pprof
。撰写时,安装 pprof
需要先安装版本为 1.16+ 的 Go,Graphviz,然后运行
go install github.com/google/pprof@latest
安装 pprof
作为 $GOPATH/bin/pprof
,其中 GOPATH
默认为 ~/go
。
注意
来自 google/pprof 的 pprof
版本与作为 gperftools
软件包一部分分发的同名旧工具不同。gperftools
版本的 pprof
不适用于 JAX。
理解 JAX 程序如何使用 GPU 或 TPU 内存
设备内存分析器的常见用途是找出为何 JAX 程序使用大量 GPU 或 TPU 内存,例如调试内存不足问题。
要将设备内存分析保存到磁盘,使用 jax.profiler.save_device_memory_profile()
。例如,考虑以下 Python 程序:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
如果我们首先运行上述程序,然后执行
pprof --web memory.prof
pprof
打开一个包含设备内存分析调用图格式的 Web 浏览器:
调用图是在每个活动缓冲区分配的 Python 栈的可视化。例如,在这个特定情况下,可视化显示 func2
及其被调用者负责分配了 76.30MB,其中 38.15MB 是在从 func1
到 func2
的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档。
使用 jax.jit()
编译的函数对设备内存分析器不透明。也就是说,任何在 jit
编译函数内部分配的内存都将归因于整个函数。
在本例中,调用 block_until_ready()
是为了确保在收集设备内存分析之前 func2
完成。有关更多详细信息,请参阅异步调度。
调试内存泄漏
我们还可以使用 JAX 设备内存分析器,通过使用 pprof
来可视化在不同时间点获取的两个设备内存配置文件中的内存使用情况变化,以追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
如果我们仅在执行结束时可视化设备内存配置文件(memory9.prof
),则可能不明显,即 anotherfunc
中的每次循环迭代都会累积更多的设备内存分配:
pprof --web memory9.prof
在 afunction
内部的大型但固定分配主导配置文件,但不会随时间增长。
通过使用 pprof
的 --diff_base
功能 来可视化循环迭代中内存使用情况的变化,我们可以找出程序内存使用量随时间增加的原因:
pprof --web --diff_base memory1.prof memory9.prof
可视化显示,内存增长可以归因于 anotherfunc
中对 normal
的调用。
在 JAX 中进行运行时值调试
是否遇到梯度爆炸?NaN 使你牙齿咬紧?只想查看计算中间值?请查看以下 JAX 调试工具!本页提供了 TL;DR 摘要,并且您可以点击底部的“阅读更多”链接了解更多信息。
目录:
-
使用
jax.debug
进行交互式检查 -
使用 jax.experimental.checkify 进行功能错误检查
-
使用 JAX 的调试标志抛出 Python 错误
使用 jax.debug
进行交互式检查
TL;DR 使用 jax.debug.print()
在 jax.jit
、jax.pmap
和 pjit
装饰的函数中将值打印到 stdout,并使用 jax.debug.breakpoint()
暂停执行编译函数以检查调用堆栈中的值:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.breakpoint()
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
点击此处了解更多!
使用 jax.experimental.checkify
进行功能错误检查
TL;DR Checkify 允许您向 JAX 代码添加 jit
可用的运行时错误检查(例如越界索引)。使用 checkify.checkify
转换以及类似断言的 checkify.check
函数,向 JAX 代码添加运行时检查:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
您还可以使用 checkify 自动添加常见检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
点击此处了解更多!
使用 JAX 的调试标志抛出 Python 错误
TL;DR 启用 jax_debug_nans
标志,自动检测在 jax.jit
编译的代码中生成 NaN 时(但不在 jax.pmap
或 jax.pjit
编译的代码中),并启用 jax_disable_jit
标志以禁用 JIT 编译,从而使用传统的 Python 调试工具如 print
和 pdb
。
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
点击此处了解更多!
阅读更多
-
jax.debug.print
和jax.debug.breakpoint
-
checkify
转换 -
JAX 调试标志
jax.debug.print
和 jax.debug.breakpoint
原文:
jax.readthedocs.io/en/latest/debugging/print_breakpoint.html
jax.debug
包为检查在 JIT 函数中的值提供了一些有用的工具。
使用 jax.debug.print
和其他调试回调进行调试
TL;DR 使用 jax.debug.print()
在 jit
和 pmap
装饰函数中将跟踪的数组值打印到标准输出:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
对于一些转换,如 jax.grad
和 jax.vmap
,可以使用 Python 的内置 print
函数打印数值。但是 print
在 jax.jit
或 jax.pmap
下不起作用,因为这些转换会延迟数值评估。因此,请使用 jax.debug.print
代替!
语义上,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
除了可以被 JAX 分阶段化和转换外。有关更多详细信息,请参阅 API 参考
。
注意,fmt
不能是 f-string,因为 f-string 会立即格式化,而对于 jax.debug.print
,我们希望延迟到稍后再格式化。
何时使用“debug”打印?
对于动态(即跟踪的)数组值在 JAX 转换如 jit
、vmap
等中,应使用 jax.debug.print
进行打印。对于静态值(如数组形状或数据类型),可以使用普通的 Python print
语句。
为什么使用“debug”打印?
以调试为名,jax.debug.print
可以显示有关计算如何评估的信息:
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
注意,打印的结果是以不同的顺序显示的!
通过揭示这些内部工作,jax.debug.print
的输出不遵守 JAX 的通常语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
计算相同的东西(以不同的方式)。然而,这些评估顺序的细节正是我们调试时想要看到的!
因此,在重视语义保证时,请使用 jax.debug.print
进行调试。
更多 jax.debug.print
的例子
除了上述使用 jit
和 vmap
的例子外,还有几个需要记住的例子。
在 jax.pmap
下打印
当使用 jax.pmap
时,jax.debug.print
可能会被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下打印
在 jax.grad
下,jax.debug.print
只会在前向传播时打印:
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
这种行为类似于 Python 内置的 print
在 jax.grad
下的工作方式。但在这里使用 jax.debug.print
,即使调用者应用 jax.jit
,行为也是相同的。
要在反向传播中打印,只需使用 jax.custom_vjp
:
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印
jax.debug.print
在其他转换如 xmap
和 pjit
中同样适用。
使用 jax.debug.callback
更多控制
实际上,jax.debug.print
是围绕 jax.debug.callback
的一个轻便封装,可以直接使用以更好地控制字符串格式化或输出类型。
语义上,jax.debug.callback
大致等同于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print
类似,这些回调只应用于调试输出,比如打印或绘图。打印和绘图相对无害,但如果用于其他用途,它的行为在转换中可能会让你感到意外。例如,不安全地用于计时操作是不安全的,因为回调可能会被重新排序并且是异步的(见下文)。
锐利的部分
像大多数 JAX API 一样,如果使用不当,jax.debug.print
也会给你带来麻烦。
打印结果的顺序
当 jax.debug.print
的不同调用涉及彼此不依赖的参数时,在分阶段时可能会被重新排序,例如通过 jax.jit
:
@jax.jit
def f(x, y):
jax.debug.print("x: {}", x)
jax.debug.print("y: {}", y)
return x + y
f(2., 3.)
# Prints: x: 2.0
# y: 3.0
# OR
# Prints: y: 3.0
# x: 2.0
为什么?在幕后,编译器获得了一个计算的功能表示,其中 Python 函数的命令顺序丢失,只有数据依赖性保留。对于功能纯粹的代码用户来说,这种变化是看不见的,但是在像打印这样的副作用存在时,就会显而易见。
要保持 jax.debug.print
在 Python 函数中的原始顺序,可以使用 jax.debug.print(..., ordered=True)
,这将确保打印的相对顺序保持不变。但是在 jax.pmap
和涉及并行性的其他 JAX 转换中使用 ordered=True
会引发错误,因为在并行执行中无法保证顺序。
异步回调
根据后端不同,jax.debug.print
可能会异步执行,即不在主程序线程中。这意味着值可能在您的 JAX 函数返回值后才被打印到屏幕上。
@jax.jit
def f(x):
jax.debug.print("x: {}", x)
return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.
要阻塞函数中的 jax.debug.print
,您可以调用 jax.effects_barrier()
,它会等待函数中任何剩余的副作用也完成:
@jax.jit
def f(x):
jax.debug.print("x: {}", x)
return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>
性能影响
不必要的实现
虽然 jax.debug.print
设计为性能影响最小,但它可能会干扰编译器优化,并且可能会影响 JAX 程序的内存配置文件。
def f(w, b, x):
logits = w.dot(x) + b
jax.debug.print("logits: {}", logits)
return jax.nn.relu(logits)
在这个例子中,我们在线性层和激活函数之间打印中间值。像 XLA 这样的编译器可以执行融合优化,可以避免在内存中实现 logits
。但是当我们在 logits
上使用 jax.debug.print
时,我们强制这些中间值被实现,可能会减慢程序速度并增加内存使用。
此外,当使用 jax.debug.print
与 jax.pjit
时,会发生全局同步,将值实现在单个设备上。
回调开销
jax.debug.print
本质上会在加速器和其主机之间进行通信。底层机制因后端而异(例如 GPU vs TPU),但在所有情况下,我们需要将打印的值从设备复制到主机。在 CPU 情况下,此开销较小。
此外,当使用 jax.debug.print
与 jax.pjit
时,会发生全局同步,增加了一些额外开销。
jax.debug.print
的优势和限制
优势
-
打印调试简单直观
-
jax.debug.callback
可用于其他无害的副作用
限制
-
添加打印语句是一个手动过程
-
可能会对性能产生影响
使用 jax.debug.breakpoint()
进行交互式检查
TL;DR 使用 jax.debug.breakpoint()
暂停执行您的 JAX 程序以检查值:
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是 jax.debug.callback(...)
的一种应用,用于捕获调用堆栈信息。因此它与 jax.debug.print
具有相同的转换行为(例如,对 jax.debug.breakpoint()
进行 vmap
-ing 会将其展开到映射的轴上)。
用法
在编译的 JAX 函数中调用 jax.debug.breakpoint()
会在命中断点时暂停程序。您将看到一个类似 pdb
的提示符,允许您检查调用堆栈中的值。与 pdb
不同的是,您不能逐步执行程序,但可以恢复执行。
调试器命令:
-
help
- 打印出可用的命令 -
p
- 评估表达式并打印其结果 -
pp
- 评估表达式并漂亮地打印其结果 -
u(p)
- 上移一个堆栈帧 -
d(own)
- 下移一个堆栈帧 -
w(here)/bt
- 打印出回溯 -
l(ist)
- 打印出代码上下文 -
c(ont(inue))
- 恢复程序的执行 -
q(uit)/exit
- 退出程序(在 TPU 上不起作用)
示例
与 jax.lax.cond
结合使用
当与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
锐利的特性
因为 jax.debug.breakpoint
只是 jax.debug.callback
的一种应用,所以它与 jax.debug.print
一样具有锐利的特性,但也有一些额外的注意事项:
-
jax.debug.breakpoint
比jax.debug.print
更多地实现了中间值,因为它强制实现了调用堆栈中的所有值。 -
jax.debug.breakpoint
的运行时开销比jax.debug.print
更大,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优势和限制
优势
-
简单、直观且(在某种程度上)标准
-
可以同时检查多个值,上下跟踪调用堆栈。
限制
-
可能需要使用多个断点来准确定位错误的源头
-
会产生许多中间值
checkify
转换
原文:
jax.readthedocs.io/en/latest/debugging/checkify_guide.html
TL;DR checkify
允许您向您的 JAX 代码添加可jit
的运行时错误检查(例如越界索引)。使用checkify.checkify
转换与类似断言的checkify.check
函数一起向 JAX 代码添加运行时检查:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
您还可以使用checkify
来自动添加常见的检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
功能化检查
与 assert 类似的检查 API 本身不是函数纯粹的:它可以作为副作用引发 Python 异常,就像 assert 一样。因此,它不能与jit
、pmap
、pjit
或scan
分阶段执行:
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是checkify
转换功能化(或卸载)这些效果。一个经过checkify
转换的函数将错误值作为新输出返回,并保持函数纯粹。这种功能化意味着checkify
转换的函数可以与我们喜欢的任何分阶段/转换进行组合:
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
JAX 为什么需要checkify
?
在某些 JAX 转换下,您可以使用普通的 Python 断言表达运行时错误检查,例如仅使用jax.grad
和jax.numpy
时。
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
但是普通的断言在jit
、pmap
、pjit
或scan
中不起作用。在这些情况下,数值计算是在 Python 执行期间被分阶段地进行评估,因此数值值不可用:
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
在组合多个转换时,JAX 转换语义依赖于函数纯度,因此我们如何在不干扰所有这些的情况下提供一个错误机制?除了需要一个新的 API 之外,情况还更加棘手:XLA HLO 不支持断言或抛出错误,因此即使我们有一个能够分阶段断言的 JAX API,我们如何将这些断言降低到 XLA 呢?
您可以想象手动向函数添加运行时检查并通过值来传递表示错误:
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
错误是由函数计算出的常规值,并且错误是在f_checked
外部引发的。f_checked
是函数式纯粹的,因此我们知道通过构造,它已经可以与jit
、pmap
、pjit
、scan
以及所有 JAX 的转换一起工作。唯一的问题是这些管道可能会很麻烦!
checkify
为您完成了这个重写工作:包括通过函数传递错误值、将检查重写为布尔操作并将结果与跟踪的错误值合并,并将最终错误值作为检查函数的输出返回:
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1\. must be positive! (check failed at <...>:2 (f))
我们称这个过程为功能化或者通过调用检查引入的效果。 (在上面的“手动”示例中,错误值只是一个布尔值。checkify
的错误值在概念上类似,但还跟踪错误消息并公开抛出和获取方法;参见jax.experimental.checkify
)。checkify.check
还允许您通过将其作为格式参数提供给错误消息来将运行时值添加到您的错误消息中。
您现在可以手动为您的代码添加运行时检查,但 checkify
也可以自动添加常见错误的检查!考虑这些错误情况:
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
默认情况下,checkify
仅释放 checkify.check
,不会捕获类似上述的错误。但如果您要求,checkify
也会自动在您的代码中添加检查。
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
基于 Sets 的 API,用于选择要启用的自动检查。详见 jax.experimental.checkify
获取更多详情。
在 JAX 变换下的 checkify
。
如上例所示,checkified 函数可以愉快地进行 jitted 处理。以下是 checkify
与其他 JAX 变换的几个示例。请注意,checkified 函数在功能上是纯粹的,并且应与所有 JAX 变换轻松组合!
jit
您可以安全地向 checkified 函数添加 jax.jit
,或者 checkify
一个 jitted 函数,两者都可以正常工作。
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap
/pmap
您可以 vmap
和 pmap
checkified 函数(或 checkify
映射函数)。映射一个 checkified 函数将为您提供一个映射的错误,该错误可以包含映射维度的每个元素的不同错误。
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
然而,checkify-of-vmap
将产生单个(未映射)的错误!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit
对于 checkified 函数的 pjit
可以正常工作,您只需为错误值输出的 out_axis_resources
指定额外的 None
。
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad
如果您使用 checkify-of-grad
,还将对您的梯度计算进行检查:
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
请注意,f
中没有乘法,但在其梯度计算中有乘法(这就是生成 NaN 的地方!)。因此,请使用 checkify-of-grad
为前向和后向传递操作添加自动检查。
checkify.check
仅应用于函数的主值。如果您想在梯度值上使用 check
,请使用 custom_vjp
:
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
jax.experimental.checkify
的优势和限制
优势
-
您可以在任何地方使用它(错误只是“值”,并在像其他值一样的转换下直观地表现)。
-
自动插装:您无需对代码进行本地修改。相反,
checkify
可以为其所有部分添加插装!
限制
-
添加大量运行时检查可能很昂贵(例如,对每个原语添加 NaN 检查将增加计算中的许多操作)。
-
需要将错误值从函数中线程化并手动抛出错误。如果未显式抛出错误,则可能会错过错误!
-
抛出一个错误值将在主机上实现该错误值,这意味着它是一个阻塞操作,这会打败 JAX 的异步先行运行。
JAX 调试标志
JAX 提供了标志和上下文管理器,可更轻松地捕获错误。
jax_debug_nans
配置选项和上下文管理器
简而言之 启用 jax_debug_nans
标志可自动检测在 jax.jit
编译的代码中产生 NaN(但不适用于 jax.pmap
或 jax.pjit
编译的代码)。
jax_debug_nans
是一个 JAX 标志,当启用时,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理——如果从 JIT 编译函数检测到 NaN 输出,函数会急切地重新运行(即不经过编译),并在产生 NaN 的具体原始基元处引发错误。
用法
如果您想追踪函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:
-
设置
JAX_DEBUG_NANS=True
环境变量; -
在主文件顶部附近添加
jax.config.update("jax_debug_nans", True)
; -
在主文件添加
jax.config.parse_flags_with_absl()
,然后像--jax_debug_nans=True
这样使用命令行标志设置选项;
示例
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
jax_debug_nans
的优势和限制
优势
-
易于应用
-
精确检测产生 NaN 的位置
-
抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
-
与
jax.pmap
或jax.pjit
不兼容 -
急切重新运行函数可能会很慢
-
误报阳性(例如故意创建 NaN)
jax_disable_jit
配置选项和上下文管理器
简而言之 启用 jax_disable_jit
标志可禁用 JIT 编译,从而启用传统的 Python 调试工具如 print
和 pdb
。
jax_disable_jit
是一个 JAX 标志,当启用时,会在整个 JAX 中禁用 JIT 编译(包括在控制流函数如 jax.lax.cond
和 jax.lax.scan
中)。
用法
您可以通过以下方式禁用 JIT 编译:
-
设置
JAX_DISABLE_JIT=True
环境变量; -
在主文件顶部附近添加
jax.config.update("jax_disable_jit", True)
; -
在主文件添加
jax.config.parse_flags_with_absl()
,然后像--jax_disable_jit=True
这样使用命令行标志设置选项;
示例
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
jax_disable_jit
的优势和限制
优势
-
易于应用
-
启用 Python 内置的
breakpoint
和print
-
抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
-
与
jax.pmap
或jax.pjit
不兼容 -
在没有 JIT 编译的情况下运行函数可能会很慢
GPU 性能提示
本文档专注于神经网络工作负载的性能提示。
矩阵乘法精度
在像 Nvidia A100 一代或更高的最新 GPU 代中,将大多数计算以 bfloat16
精度执行可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
实例化 Dense
层。以下是一些代码示例:
-
在 Flax LM1B example 中,
Dense
模块也可以使用可配置的数据类型 进行实例化,其 默认值 为 bfloat16。 -
在 MaxText 中,
DenseGeneral
模块也可用可配置的数据类型 进行实例化,其 默认值为 bfloat16。
XLA 性能标志
注意
JAX-Toolbox 还有一个关于 NVIDIA XLA 性能 FLAGS 的页面。
XLA 标志的存在和确切行为可能取决于 jaxlib
版本。
截至 jaxlib==0.4.18
(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。其中一些与多 GPU 之间的通信相关,因此仅在多设备运行计算时才相关,而其他一些与每个设备上的代码生成相关。
未来版本中可能会默认设置其中一些。
这些标志可以通过 XLA_FLAGS
shell 环境变量进行设置。例如,我们可以将其添加到 Python 文件的顶部:
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_enable_triton_softmax_fusion=true '
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_async_collectives=true '
'--xla_gpu_enable_latency_hiding_scheduler=true '
'--xla_gpu_enable_highest_priority_async_stream=true '
)
更多示例,请参阅 XLA Flags recommended for Pax training on Nvidia GPUs。
代码生成标志
-
–xla_gpu_enable_triton_softmax_fusion 此标志启用基于 Triton 代码生成支持的模式匹配自动 softmax 融合。默认值为 False。
-
–xla_gpu_triton_gemm_any 使用基于 Triton 的 GEMM(矩阵乘法)发射器支持的任何 GEMM。默认值为 False。
通信标志
-
–xla_gpu_enable_async_collectives 此标志启用诸如
AllReduce
、AllGather
、ReduceScatter
和CollectivePermute
等集体操作以异步方式进行。异步通信可以将跨核心通信与计算重叠。默认值为 False。 -
–xla_gpu_enable_latency_hiding_scheduler 这个标志启用了延迟隐藏调度器,可以高效地将异步通信与计算重叠。默认值为 False。
-
–xla_gpu_enable_pipelined_collectives 在使用管道并行时,此标志允许将(i+1)层权重的
AllGather
与第 i 层的计算重叠。它还允许将(i+1)层权重的Reduce
/ReduceScatter
与第 i 层的计算重叠。默认值为 False。在启用此标志时存在一些错误。 -
–xla_gpu_collective_permute_decomposer_threshold 当执行GSPMD pipelining时,这个标志非常有用。设置一个非零的阈值会将
CollectivePermute
分解为CollectivePermuteReceiveDone
和CollectivePermuteSendDone
对,从而可以在每个对应的ReceiveDone
/SendDone
对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为大于 0 的阈值,例如--xla_gpu_collective_permute_decomposer_threshold=1024
,可以启用此功能。 -
–xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的
AllGather
/ReduceScatter
/AllReduce
组合成一个大的AllGather
/ReduceScatter
/AllReduce
,以减少跨设备通信所花费的时间。例如,在基于 Transformer 的工作负载上,可以考虑将AllGather
/ReduceScatter
阈值调高,以至少组合一个 Transformer 层的权重AllGather
/ReduceScatter
。默认情况下,combine_threshold_bytes
设置为 256。
NCCL 标志
这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。
多进程
我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize()
API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。
持久编译缓存
原文:
jax.readthedocs.io/en/latest/persistent_compilation_cache.html
JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。
使用
当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
有关cache-location
的更多详细信息,请参见以下各节。
set_cache_dir()
是设置cache-location
的另一种方法。
本地文件系统
cache-location
可以是本地文件系统上的目录。例如:
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
注意:缓存没有实现驱逐机制。如果cache-location
是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。
Google Cloud
在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:
-
在与工作负载运行地区相同的地方创建存储桶。
-
在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
-
对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
-
对于存储桶的默认存储类别,请使用“标准”。
-
将软删除策略设置为最短期限:7 天。
-
将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用
age
作为生命周期条件,使用Delete
作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。 -
所有加密策略都受支持。
假设gs://jax-cache
是 GCS 存储桶,请设置如下cache-location
:
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
2022-06-21 数据科学 IPython 笔记本 9.2 NumPy 简介
2020-06-21 PyTorch 1.0 中文官方教程:使用字符级别特征的RNN网络生成姓氏