JAX-中文文档-五-
JAX 中文文档(五)
形状多态性
当使用 JIT 模式的 JAX 时,函数将被跟踪、降级到 StableHLO,并针对每种输入类型和形状组合进行编译。在导出函数并在另一个系统上反序列化后,我们就无法再使用 Python 源代码,因此无法重新跟踪和重新降级它。形状多态性是 JAX 导出的一个特性,允许一些导出函数用于整个输入形状家族。这些函数在导出时只被跟踪和降级一次,并且Exported
对象包含编译和执行该函数所需的信息,可以在许多具体输入形状上进行编译和执行。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,例如下面的示例:
>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x): # f: f32[a, b]
... return jnp.concatenate([x, x], axis=1)
>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")
>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)
>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)
>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)
注意,此类函数仍会按需为每个具体输入形状重新编译。仅跟踪和降级是保存的。
在上面的示例中,jax.export.symbolic_shape()
用于解析符号形状的字符串表示,将其转换为可以用于构造形状的维度表达式对象(类型为 _DimExpr
)。维度表达式对象重载了大多数整数运算符,因此在大多数情况下可以像使用整数常量一样使用它们。详细信息请参阅使用维度变量进行计算。
另外,我们提供了jax.export.symbolic_args_specs()
,可用于根据多态形状规范构建jax.ShapeDtypeStruct
对象的 pytrees:
>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
... return x + y
>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))
注意多态形状规范中的 "a, ..."
如何包含占位符 ...
,以从参数 (x, y)
的具体形状中填充。占位符 ...
代表 0 个或多个维度,而占位符 _
代表一个维度。jax.export.symbolic_args_specs()
支持参数的 pytrees,用于填充 dtypes 和任何占位符。该函数将构造与传递给它的参数结构相匹配的参数规范 pytree (jax.ShapeDtypeStruct
)。在某些情况下,多个参数应用相同规范的前缀,如上例所示。请参阅如何将可选参数匹配到参数。
几个形状规范的示例:
-
("(b, _, _)", None)
可以用于具有两个参数的函数,第一个是具有应为符号的批处理前导维度的三维数组。基于实际参数专门化第一个参数的其他维度和第二个参数的形状。请注意,如果第一个参数是具有相同前导维度但可能具有不同尾部维度的多个三维数组的 pytree,则相同的规范也适用。第二个参数的值None
表示该参数不是符号化的。等效地,可以使用...
。 -
("(batch, ...)", "(batch,)")
指定两个参数具有匹配的前导维度,第一个参数至少具有秩为 1,第二个具有秩为 1。
形状多态的正确性
我们希望信任导出的程序在编译和执行适用于任何具体形状时产生与原始 JAX 程序相同的结果。更确切地说:
对于任何 JAX 函数f
和包含符号形状的参数规范arg_spec
,以及任何形状与arg_spec
匹配的具体参数arg
:
-
如果 JAX 本地执行在具体参数上成功:
res = f(arg)
, -
如果导出使用符号形状成功:
exp = export.export(f)(arg_spec)
, -
编译和运行导出程序将会成功并得到相同的结果:
res == exp.call(arg)
非常重要的是理解f(arg)
有自由重新调用 JAX 追踪机制,实际上对于每个不同的具体arg
形状都会这样做,而exp.call(arg)
的执行不能再使用 JAX 追踪(这种执行可能发生在无法访问f
源代码的环境中)。
确保这种正确性形式是困难的,在最困难的情况下,导出会失败。本章的其余部分描述了如何处理这些失败。
使用维度变量进行计算
JAX 跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 将它们计算为涉及维度变量的符号形状表达式。维度变量代表大于或等于 1 的整数值。这些符号表达式可以表示应用算术运算符(add、sub、mul、floordiv、mod,包括 NumPy 变体 np.sum
、np.prod
等)在维度表达式和整数上的结果(int
、np.int
,或者通过operator.index
可转换的任何内容)。这些符号维度随后可以在 JAX 原语和 API 的形状参数中使用,例如在jnp.reshape
、jnp.arange
、切片索引等。
例如,在以下代码中展平二维数组时,计算x.shape[0] * x.shape[1]
将计算符号维度4 * b
作为新形状:
>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)
可以将维度表达式明确转换为 JAX 数组,例如jnp.array(x.shape[0])
甚至jnp.array(x.shape)
。这些操作的结果可以用作常规的 JAX 数组,但不能再作为形状中的维度使用。
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)
>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
当符号维度与非整数(如 float
、np.float
、np.ndarray
或 JAX 数组)进行算术运算时,它会自动转换为 JAX 数组,使用 jnp.array
。例如,在下面的函数中,x.shape[0]
的所有出现都会被隐式转换为 jnp.array(x.shape[0])
,因为它们与非整数标量或 JAX 数组参与了运算:
>>> exp = export.export(jax.jit(
... lambda x: (5. + x.shape[0],
... x.shape[0] - np.arange(5, dtype=jnp.int32),
... x + x.shape[0] + jnp.sin(x.shape[0]))))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
ShapedArray(int32[5]),
ShapedArray(float32[b], weak_type=True))
>>> exp.call(jnp.ones((3,), jnp.int32))
(Array(8., dtype=float32, weak_type=True),
Array([ 3, 2, 1, 0, -1], dtype=int32),
Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))
另一个典型的例子是计算平均值(注意 x.shape[0]
如何自动转换为 JAX 数组):
>>> exp = export.export(jax.jit(
... lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)
存在形状多态性的错误
大多数 JAX 代码假定 JAX 数组的形状是整数元组,但是使用形状多态性时,某些维度可能是符号表达式。这可能导致多种错误。例如,我们可以遇到通常的 JAX 形状检查错误:
>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
... jax.ShapeDtypeStruct((v,), dtype=np.int32),
... jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).
>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
... jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).
我们可以通过指定参数的形状(v, v)
来修复上述矩阵乘法示例。
部分支持符号维度的比较
在 JAX 内部存在多个形状比较的相等性和不等式比较,例如用于形状检查或甚至用于为某些原语选择实现。比较支持如下:
-
支持等式,但有一个注意事项:如果两个符号维度在所有维度变量的赋值下都表示相同的值,则等式求值为
True
,例如对于b + b == 2*b
;否则等式求值为False
。关于此行为的重要后果,请参见下文讨论。 -
不相等总是等于等式的否定。
-
不等式部分支持,类似于部分等式。然而,在这种情况下,我们考虑维度变量只取严格正整数。例如,
b >= 1
、b >= 0
、2 * a + b >= 3
是True
,而b >= 2
、a >= b
、a - b >= 0
是不确定的并会导致异常。
在无法将比较操作解析为布尔值的情况下,我们会引发 InconclusiveDimensionOperation
。例如,
import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
... jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
如果出现 InconclusiveDimensionOperation
,您可以尝试几种策略:
-
如果您的代码使用内置的
max
或min
,或者使用np.max
或np.min
,那么可以将它们替换为core.max_dim
和core.min_dim
,这样可以将不等式比较延迟到编译时,当形状已知时。 -
尝试使用
core.max_dim
和core.min_dim
重写条件语句,例如,代替d if d > 0 else 0
,您可以写成core.max_dim(d, 0)
。 -
尝试重写代码,减少对维度应为整数的依赖,并依赖于符号维度在大多数算术运算中作为整数的鸭子类型。例如,代替
int(d) + 5
写成d + 5
。 -
按照下面的说明指定符号约束。
用户指定的符号约束
默认情况下,JAX 假定所有维度变量的取值大于或等于 1,并试图从中推导出其他简单的不等式,例如:
-
a + 2 >= 3
, -
a * 2 >= 1
, -
a + b + c >= 3
, -
a // 4 >= 0
,a**2 >= 1
,等等。
如果将符号形状规范更改为维度大小的隐式约束,可以避免一些不等比较失败。例如,
-
你可以使用
2*b
作为维度来约束它为偶数且大于或等于 2。 -
你可以使用
b + 15
作为维度来约束它至少为 16。例如,如果没有+ 15
部分,以下代码会失败,因为 JAX 将希望验证切片大小至多不超过轴大小。
>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
这些隐式符号约束用于决定比较,并且在编译时检查,如下所述。
你也可以指定显式符号约束:
>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
... constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
... jax.ShapeDtypeStruct((a, b), dtype=np.int32))
约束与隐式约束一起形成一个连接。你可以指定 >=
、<=
和 ==
约束。目前,JAX 对符号约束的推理支持有限:
-
对于形式为变量大于或等于或小于或等于常数的约束,你可以得到最大的功效。例如,从
a >= 16
和b >= 8
的约束中,我们可以推断出a + 2*b >= 32
。 -
当约束涉及更复杂的表达式时,例如从
a >= b + 8
我们可以推断出a - b >= 8
,但不能推断出a >= 9
。我们可能会在未来在这个领域有所改进。 -
等式约束被视为归一化规则。例如,
floordiv(a, b) = c
通过将所有左侧的出现替换为右侧来工作。只能有左侧是因子乘积的等式约束,例如a * b
,或4 * a
,或floordiv(a, b)
。因此,左侧不能包含顶层的加法或减法。
符号约束还可以帮助绕过 JAX 推理机制中的限制。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3
,即符号表达式 mod(b, 3)
,小于或等于轴大小 b
。对于所有严格正值的 b
来说,这是真的,但这并不是 JAX 符号比较规则能够证明的。因此,以下代码会引发错误:
from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.
一种选择是将代码限制为仅在轴大小是 3
的倍数上运行(通过在形状中用 3*b
替换 b
)。然后,JAX 将能够将模运算 mod(3*b, 3)
简化为 0
。另一种选择是添加一个带有确切不确定不等式的符号约束,JAX 正试图证明:
>>> b, = export.symbolic_shape("b",
... constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
就像隐式约束一样,显式符号约束在编译时使用相同的机制进行检查,如下所述。
符号维度范围
符号约束存储在一个αn jax.export.SymbolicScope
对象中,它会隐式地为每次调用jax.export.symbolic_shapes()
创建。您必须小心,不要混合使用不同范围的符号表达式。例如,下面的代码将失败,因为a1
和a2
使用了不同的范围(由不同调用jax.export.symbolic_shape()
创建):
>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> a1 + a2
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
a >= 8
源自单次调用jax.export.symbolic_shape()
的符号表达式共享一个范围,并且可以在算术操作中混合使用。结果也将共享相同的范围。
您可以重复使用范围:
>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope) # Reuse the scope of `a`
>>> a + b # Allowed
b + a
您也可以显式创建范围:
>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d # Allowed
d + c
JAX 跟踪使用部分以形状为键的缓存,并且如果它们使用不同的范围,则打印相同的符号形状将被视为不同的。
相等性比较的注意事项
相等比较返回False
,对于b + 1 == b
或b == 0
(在这种情况下,对于所有维度变量的值,维度肯定不同),但对于b == 1
和a == b
也是如此。这是不稳定的,我们应该引发core.InconclusiveDimensionOperation
,因为在某些估值下结果应该是True
,在其他估值下应该是False
。我们选择使相等性变得全面,从而允许不稳定性,因为否则在哈希碰撞存在时(哈希维度表达式或包含它们的对象时,如形状,core.AbstractValue
,core.Jaxpr
),我们可能会遇到虚假错误。除了哈希错误外,相等性的部分语义还会导致以下表达式的错误b == a or b == b
或b in [a, b]
,即使我们改变比较的顺序也能避免错误。
形式为if x.shape[0] != 1: raise NiceErrorMessage
的代码在处理相等性时也是合理的,但形式为if x.shape[0] != 1: return 1
的代码是不稳定的。
维度变量必须能够从输入形状中解决
目前,当调用导出对象时,通过数组参数的形状间接传递维度变量的值是唯一的方法。例如,可以在调用类型为f32[b]
的第一个参数的形状中推断出b
的值。这对大多数用例都很有效,并且它反映了 JIT 函数的调用约定。
有时您可能希望导出一个由整数值参数化的函数,这些值确定程序中的某些形状。例如,我们可能希望导出下面定义的函数my_top_k
,其由值k
参数化,该值确定了结果的形状。下面的尝试将导致错误,因为维度变量k
不能从输入x: i32[4, 10]
的形状中推导出来:
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10
... return lax.top_k(x, k)[0] # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))
>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])
>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])
>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)
Traceback (most recent call last):
KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
未来,我们可能会添加额外的机制来传递维度变量的值,除了通过输入形状隐式传递外。与此同时,解决上述用例的方法是将函数参数k
替换为形状为(0, k)
的数组,这样k
可以从数组的输入形状中推导出来。第一个维度为 0 是为了确保整个数组为空,在调用导出函数时不会有性能惩罚。
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10]
... return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
... jax.ShapeDtypeStruct((0, k), dtype=np.int32),
... x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
>>> exp.out_avals[0]
ShapedArray(int32[4,k])
>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9, 8, 7],
[19, 18, 17],
[29, 28, 27],
[39, 38, 37]], dtype=int32)
另一种可能出现错误的情况是一些维度变量出现在输入形状中,但以 JAX 目前无法解决的非线性表达式形式出现:
>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
... jax.ShapeDtypeStruct((a * a,), dtype=np.int32))
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a²,).
Unprocessed specifications: 'a²' for dimension size args[0].shape[0].
形状断言错误
JAX 假设维度变量在严格正整数范围内,这一假设在为具体输入形状编译代码时被检查。
例如,对于符号输入形状(b, b, 2*d)
,当使用实际参数arg
调用时,JAX 将生成代码来检查以下断言:
-
arg.shape[0] >= 1
-
arg.shape[1] == arg.shape[0]
-
arg.shape[2] % 2 == 0
-
arg.shape[2] // 2 >= 1
例如,这是在对形状为(3, 3, 5)
的参数调用导出函数时得到的错误:
>>> def f(x): # x: f32[b, b, 2*d]
... return x
>>> exp = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.
这些错误出现在编译之前的预处理步骤中。
部分支持符号维度的除法
JAX 将尝试简化除法和取模运算,例如(a * b + a) // (b + 1) == a
和6*a + 4 % 3 == 1
。特别地,JAX 会处理以下情况:要么(a)没有余数,要么(b)除数是一个常数,此时可能有一个常数余数。
例如,尝试计算reshape
操作的推断维度时,以下代码会导致除法错误:
>>> b, = export.symbolic_shape("b")
>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
Traceback (most recent call last):
jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
The remainder mod(b, - 2) should be 0.
注意以下操作将成功:
>>> b, = export.symbolic_shape("b")
>>> # We specify that the first dimension is a multiple of 4
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,2*b]),)
>>> # We specify that some other dimension is even
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
>>> exp.out_avals
(ShapedArray(int32[2,15*b]),)
与 TensorFlow 的互操作
参见JAX2TF 文档。
JAX 错误
此页面列出了在使用 JAX 时可能遇到的一些错误,以及如何修复它们的代表性示例。
class jax.errors.ConcretizationTypeError(tracer, context='')
当 JAX 追踪器对象在需要具体值的上下文中使用时(参见关于 Tracer 是什么的更多信息),会发生此错误。在某些情况下,可以通过将问题值标记为静态来轻松修复;在其他情况下,可能表明您的程序正在执行 JAX JIT 编译模型不直接支持的操作。
例子:
在期望静态值的位置使用跟踪值
导致此错误的一个常见原因是在需要静态值的位置使用跟踪值。例如:
>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
... return x.min(axis)
>>> func(jnp.arange(4), 0)
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().
通常可以通过将问题参数标记为静态来解决此问题:
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
... return x.min(axis)
>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
形状依赖于跟踪的值
在 JIT 编译的计算中,如果形状依赖于跟踪数量中的值时,也可能出现此类错误。例如:
>>> @jit
... def func(x):
... return jnp.where(x < 0)
>>> func(jnp.arange(4))
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.
这是一个与 JAX JIT 编译模型不兼容的操作示例,该模型要求在编译时知道数组大小。这里返回的数组大小取决于 x 的内容,这样的代码不能 JIT 编译。
在许多情况下,可以通过修改函数中使用的逻辑来解决此问题;例如,这里是一个类似问题的代码:
>>> @jit
... def func(x):
... indices = jnp.where(x > 1)
... return x[indices].sum()
>>> func(jnp.arange(4))
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.
以下是如何以避免创建动态大小索引数组的方式表达相同操作的示例:
>>> @jit
... def func(x):
... return jnp.where(x > 1, x, 0).sum()
>>> func(jnp.arange(4))
Array(5, dtype=int32)
要了解与跟踪器与常规值,具体与抽象值相关的更多细微差别,可以阅读有关不同类型的 JAX 值的内容。
参数:
-
追踪器 (core.Tracer)
-
上下文 (str)
class jax.errors.KeyReuseError(message)
当 PRNG 密钥以不安全的方式重复使用时,会发生此错误。仅在设置 jax_debug_key_reuse
为 True 时检查密钥重复使用。
以下是导致此类错误的代码简单示例:
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... value = jax.random.uniform(key)
... new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
此类密钥重用存在问题,因为 JAX PRNG 是无状态的,必须手动分割密钥;有关更多信息,请参见 Sharp Bits: Random Numbers。
参数:
消息 (str)
class jax.errors.NonConcreteBooleanIndexError(tracer)
当程序尝试在跟踪索引操作中使用非具体布尔索引时,会发生此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此布尔掩码必须小心使用。某些逻辑通过布尔掩码实现可能在 jax.jit()
函数中根本不可能;在其他情况下,可以使用 where()
的三参数版本以 JIT 兼容的方式重新表达逻辑。
以下是可能导致此错误的几个示例。
通过布尔掩码构建数组
在尝试在 JIT 上下文中通过布尔遮罩创建数组时最常见出现此错误。例如:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.jit
... def positive_values(x):
... return x[x > 0]
>>> positive_values(jnp.arange(-5, 5))
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
此函数试图仅返回输入数组中的正值;除非将 x 标记为静态,否则在编译时无法确定返回数组的大小,因此无法在 JIT 编译下执行此类操作。
可重新表达的布尔逻辑
尽管不直接支持创建动态大小的数组,但在许多情况下可以重新表达计算逻辑以符合 JIT 兼容的操作。例如,以下是另一个因相同原因在 JIT 下失败的函数:
>>> @jax.jit
... def sum_of_positive(x):
... return x[x > 0].sum()
>>> sum_of_positive(jnp.arange(-5, 5))
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
然而,在这种情况下,有问题的数组仅是一个中间值,我们可以使用支持 JIT 的三参数版本的 jax.numpy.where()
表达相同的逻辑:
>>> @jax.jit
... def sum_of_positive(x):
... return jnp.where(x > 0, x, 0).sum()
>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)
将布尔遮罩替换为带有三个参数的 where()
的模式是解决这类问题的常见方法。
对 JAX 数组进行布尔索引
另一个经常出现此错误的情况是使用布尔索引,例如 .at[...].set(...)
。以下是一个简单的示例:
>>> @jax.jit
... def manual_clip(x):
... return x.at[x < 0].set(0)
>>> manual_clip(jnp.arange(-2, 2))
Traceback (most recent call last):
...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
此函数试图将小于零的值设置为标量填充值。与上述类似,可以通过在 where()
中重新表达逻辑来解决此问题:
>>> @jax.jit
... def manual_clip(x):
... return jnp.where(x < 0, 0, x)
>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)
参数:
tracer (core.Tracer)
class jax.errors.TracerArrayConversionError(tracer)
当程序尝试将 JAX 追踪对象转换为标准的 NumPy 数组时会发生此错误(详见不同类型的 JAX 值,了解追踪器的更多信息)。通常情况下会发生在几种情况之一。
在 JAX 变换中使用非 JAX 函数
如果尝试在 JAX 变换(jit()
、grad()
、jax.vmap()
等)内部使用非 JAX 库如 numpy
或 scipy
,则可能会导致此错误。例如:
>>> from jax import jit
>>> import numpy as np
>>> @jit
... def func(x):
... return np.sin(x)
>>> func(np.arange(4))
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]
在这种情况下,你可以通过使用 jax.numpy.sin()
替换 numpy.sin()
来解决问题:
>>> import jax.numpy as jnp
>>> @jit
... def func(x):
... return jnp.sin(x)
>>> func(jnp.arange(4))
Array([0\. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
另请参阅 External Callbacks 了解从转换的 JAX 代码返回到主机端计算的选项。
使用追踪器索引 numpy 数组
如果此错误出现在涉及数组索引的行上,则可能是被索引的数组 x
是标准的 numpy.ndarray,而索引 idx
是追踪的 JAX 数组。例如:
>>> x = np.arange(10)
>>> @jit
... def func(i):
... return x[i]
>>> func(0)
Traceback (most recent call last):
...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]
根据上下文,你可以通过将 numpy 数组转换为 JAX 数组来解决此问题:
>>> @jit
... def func(i):
... return jnp.asarray(x)[i]
>>> func(0)
Array(0, dtype=int32)
或者通过将索引声明为静态参数:
>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
... return x[i]
>>> func(0)
Array(0, dtype=int32)
要了解关于追踪器与常规值、具体值与抽象值的更多微妙之处,可以阅读有关不同类型的 JAX 值。
参数:
tracer (core.Tracer)
class jax.errors.TracerBoolConversionError(tracer)
当在期望布尔值的上下文中使用 JAX 中的追踪值时会出现此错误(详见不同类型的 JAX 值,了解追踪器的更多信息)。
布尔转换可以是显式的(例如bool(x)
)或隐式的,通过控制流的使用(例如if x > 0
或while x
)、使用 Python 布尔运算符(例如z = x and y
、z = x or y
、z = not x
)或使用它们的函数(例如z = max(x, y)
、z = min(x, y)
等)。
在某些情况下,通过将跟踪值标记为静态,可以轻松解决此问题;在其他情况下,这可能表明您的程序正在执行 JAX JIT 编译模型不直接支持的操作。
示例:
在控制流中使用跟踪值
一个经常出现这种情况的案例是,当跟踪值用于 Python 控制流时。例如:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
... return x if x.sum() < y.sum() else y
>>> func(jnp.ones(4), jnp.zeros(4))
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
我们可以将输入的x
和y
都标记为静态,但这样做将破坏在这里使用jax.jit()
的目的。另一个选择是将 if 语句重新表达为三项jax.numpy.where()
:
>>> @jit
... def func(x, y):
... return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)
对于包括循环在内的更复杂的控制流,请参阅控制流运算符。
跟踪值在控制流中的使用
另一个常见的错误原因是,如果您无意中在布尔标志上进行跟踪。例如:
>>> @jit
... def func(x, normalize=True):
... if normalize:
... return x / x.sum()
... return x
>>> func(jnp.arange(5), True)
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这里,因为标志normalize
被跟踪,所以不能在 Python 控制流中使用它。在这种情况下,最好的解决方案可能是将此值标记为静态:
>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
... if normalize:
... return x / x.sum()
... return x
>>> func(jnp.arange(5), True)
Array([0\. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
有关static_argnums
的更多信息,请参阅jax.jit()
的文档。
使用非 JAX 感知的函数
另一个常见的错误原因是在 JAX 代码中使用非 JAX 感知的函数。例如:
>>> @jit
... def func(x):
... return min(x, 0)
>>> func(2)
Traceback (most recent call last):
...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这种情况下,错误是因为 Python 的内置min
函数与 JAX 变换不兼容。可以通过将其替换为jnp.minimum
来修复这个问题:
>>> @jit
... def func(x):
... return jnp.minimum(x, 0)
>>> print(func(2))
0
要更深入了解关于跟踪器与常规值、具体值与抽象值之间的微妙差别,您可能需要阅读关于不同类型 JAX 值的文档。
参数:
tracer(core.Tracer)
class jax.errors.TracerIntegerConversionError(tracer)
如果在期望 Python 整数的上下文中使用 JAX Tracer 对象,则可能会出现此错误(有关 Tracer 是什么的更多信息,请参阅关于不同类型 JAX 值的内容)。它通常发生在几种情况下。
将跟踪器放在整数位置
如果您试图将跟踪值传递给需要静态整数参数的函数,则可能会出现此错误;例如:
>>> from jax import jit
>>> import numpy as np
>>> @jit
... def func(x, axis):
... return np.split(x, 2, axis)
>>> func(np.arange(4), 0)
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]
当出现这种情况时,解决方案通常是将有问题的参数标记为静态:
>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
... return np.split(x, 2, axis)
>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
Array([5, 6, 7, 8, 9], dtype=int32)]
另一种方法是将转换应用于封装要保护参数的闭包,可以手动执行如下或使用functools.partial()
:
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,这也是为什么首选static_argnums
的原因。
使用跟踪器索引列表
如果您尝试使用跟踪的量索引 Python 列表,则可能会出现此错误。例如:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> L = [1, 2, 3]
>>> @jit
... def func(i):
... return L[i]
>>> func(0)
Traceback (most recent call last):
...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]
根据上下文,通常可以通过将列表转换为 JAX 数组来解决此问题:
>>> @jit
... def func(i):
... return jnp.array(L)[i]
>>> func(0)
Array(1, dtype=int32)
或者通过将索引声明为静态参数来声明:
>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
... return L[i]
>>> func(0)
Array(1, dtype=int32, weak_type=True)
要更深入理解跟踪器与常规值以及具体与抽象值之间的微妙差别,您可以阅读有关不同类型 JAX 值的文档。
参数:
tracer(core.Tracer)
class jax.errors.UnexpectedTracerError(msg)
当您使用从函数中泄漏出来的 JAX 值时,会出现此错误。泄漏值是什么意思?如果您对函数f
应用 JAX 转换,并在f
外某个作用域存储了一个中间值的引用,那么该值被视为已泄漏。泄漏值是副作用。(阅读更多关于避免副作用的内容,请参阅Pure Functions)
JAX 在你稍后在另一个操作中使用泄露的值时检测到泄漏,此时会引发UnexpectedTracerError
。要修复此问题,请避免副作用:如果一个函数计算了外部作用域需要的值,则需要明确从转换后的函数中返回该值。
具体来说,Tracer
是 JAX 在转换期间函数中间值的内部表示,例如在jit()
、pmap()
、vmap()
等内部。在转换之外遇到Tracer
表示泄漏。
泄漏值的生命周期
请考虑以下转换函数的示例,它将一个值泄漏到外部作用域:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit # 1
... def side_effecting(x):
... y = x + 1 # 3
... outs.append(y) # 4
>>> x = 1
>>> side_effecting(x) # 2
>>> outs[0] + 1 # 5
Traceback (most recent call last):
...
UnexpectedTracerError: Encountered an unexpected tracer.
在此示例中,我们从内部转换作用域泄漏了一个跟踪值到外部作用域。当使用泄漏值而不是泄漏值时,会出现UnexpectedTracerError
。
此示例还展示了泄漏值的生命周期:
- 函数被转换了(在本例中,通过
jit()
)。- 调用了转换后的函数(启动函数的抽象跟踪,并将
x
转换为Tracer
)。- 中间值
y
被创建,稍后将被泄漏(跟踪函数的中间值也是Tracer
)。- 该值已泄漏(通过外部作用域的一个侧通道将其追加到列表中逃逸函数)
- 使用了泄漏的值,并引发了 UnexpectedTracerError。
UnexpectedTracerError 消息试图通过包含有关每个阶段信息的方法来指出代码中的这些位置。依次:
- 转换后函数的名称(
side_effecting
)以及触发跟踪的转换名称jit()
)。- 泄漏的 Tracer 创建时的重构堆栈跟踪,包括调用转换后函数的位置。(
When the Tracer was created, the final 5 stack frames were...
)。- 从重构的堆栈跟踪中,创建泄漏 Tracer 的代码行。
- 错误消息中不包括泄漏位置,因为难以确定!JAX 只能告诉你泄漏值的外观(其形状和创建位置)以及泄漏的边界(变换的名称和转换后函数的名称)。
- 当前错误的堆栈跟踪指向值的使用位置。
可以通过将值从转换函数返回来修复错误:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit
... def not_side_effecting(x):
... y = x+1
... return y
>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1 # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)
泄漏检查器
如上述第 2 和第 3 点所讨论的那样,JAX 显示了一个重建的堆栈跟踪,指出了泄露值的创建位置。这是因为 JAX 仅在使用泄露值时才会引发错误,而不是在值泄漏时。这不是引发此错误的最有用的地方,因为您需要知道泄露跟踪器的位置来修复错误。
为了更容易跟踪此位置,您可以使用泄漏检查器。当启用泄漏检查器时,一旦泄露了Tracer
,就会引发错误。(更确切地说,在从中泄漏Tracer
的转换函数返回时会引发错误)
要启用泄漏检查器,可以使用JAX_CHECK_TRACER_LEAKS
环境变量或with jax.checking_leaks()
上下文管理器。
注意
请注意,此工具属于实验性质,可能会报告错误的情况。它通过禁用某些 JAX 缓存工作,因此会对性能产生负面影响,应仅在调试时使用。
示例用法:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> outs = []
>>> @jit
... def side_effecting(x):
... y = x+1
... outs.append(y)
>>> x = 1
>>> with jax.checking_leaks():
... y = side_effecting(x)
Traceback (most recent call last):
...
Exception: Leaked Trace
参数:
msg (str)
转移保护
JAX 可能在类型转换和输入分片期间在主机和设备之间传输数据。为了记录或阻止任何意外的转移,用户可以配置 JAX 转移保护。
JAX 转移保护区分两种类型的转移:
-
显式转移:
jax.device_put*()
和jax.device_get()
调用。 -
隐式转移:其他转移(例如打印
DeviceArray
)。
转移保护可以根据其保护级别采取行动:
-
"allow"
: 静默允许所有转移(默认)。 -
"log"
: 记录并允许隐式转移。静默允许显式转移。 -
"disallow"
: 禁止隐式转移。静默允许显式转移。 -
"log_explicit"
: 记录并允许所有转移。 -
"disallow_explicit"
: 禁止所有转移。
当禁止转移时,JAX 将引发 RuntimeError
。
转移保护使用标准的 JAX 配置系统:
-
一个
--jax_transfer_guard=GUARD_LEVEL
命令行标志和jax.config.update("jax_transfer_guard", GUARD_LEVEL)
将设置全局选项。 -
一个
with jax.transfer_guard(GUARD_LEVEL): ...
上下文管理器将在上下文管理器的作用域内设置线程局部选项。
注意,类似于其他 JAX 配置选项,新生成的线程将使用全局选项,而不是生成线程所在作用域的任何活动线程局部选项。
转移保护还可以根据转移方向更为选择性地应用。标志和上下文管理器名称以相应的转移方向作为后缀(例如 --jax_transfer_guard_host_to_device
和 jax.config.transfer_guard_host_to_device
):
-
"host_to_device"
: 将 Python 值或 NumPy 数组转换为 JAX 设备上的缓冲区。 -
"device_to_device"
: 将 JAX 设备缓冲区复制到另一个设备。 -
"device_to_host"
: 从 JAX 设备缓冲区获取数据。
获取 CPU 设备上的缓冲区始终允许,无论转移保护级别如何。
下面展示了使用转移保护的示例。
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x) # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
... print("x", x) # x has already been fetched into the host.
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
... try:
... print("z", z) # Implicit transfers are disallowed.
... assert False, "This line is expected to be unreachable."
... except:
... print("z could not be fetched")
x 1
y 2
z could not be fetched
Pallas:一个 JAX 内核语言
Pallas 是 JAX 的扩展,允许为 GPU 和 TPU 编写自定义内核。本节包含使用 Pallas 的教程、指南和示例。
指南
-
Pallas 设计
-
介绍
-
Pallas:为内核扩展 JAX
-
-
Pallas 快速入门
-
在 Pallas 中的 Hello world
-
Pallas 编程模型
-
平台特性
-
Pallas TPU
-
使用 Pallas 编写 TPU 内核
-
流水线和
BlockSpec
-
Pallas 设计
在这份文档中,我们解释了初始的 Pallas 设计。这是一些早期设计决策的快照,并且 Pallas 的特定 API 可能已经发生了变化。
Introduction
JAX 被用于各种工作负载,从大规模机器学习到科学计算。JAX 的成功故事也是 XLA 的成功故事,XLA 是 JAX 的主要编译器目标——XLA 为加速器编译 JAX 程序,并使 JAX 能够扩展到最大的 ML 模型。JAX 描述了在 XLA 表示 HLO 中的逻辑计算。HLO 描述了逻辑上的计算过程,但不涉及物理执行。对于广泛的 ML 应用,XLA 在编译用户程序方面表现良好,但不可避免地,一些用户会遇到 XLA 的限制。在这些情况下,我们需要提供一个“逃生通道”,让专家编写手动调优的内核,以在那个时刻超越 XLA 的性能。此外,ML 系统研究的进展需要一些时间才能被整合到 XLA 中,而用户通常希望提前使用这些优化。随着时间的推移,编译器可以通过手动调优的内核整合已经通过实验验证的优化。
XLA 确实提供了CustomCall
机制作为一种逃生口,但这需要用户编写 C++代码,在 GPU 上还需要用户了解 CUDA 编程模型。CUDA 编程模型对于许多机器学习 GPU 核心(如矩阵乘法或多头注意力)来说可能过于低级,即使是专家用户也会在使用 CUDA 来实现高效的矩阵乘法或多头注意力时遇到困难。此外,JAX 用户通常熟悉 Python 和类似 NumPy 的数组编程,不涉及编写任何 C++代码或考虑 GPU 并行性。所有流行的机器学习框架都共享这一思想:通过高级操作(如matmul
或convolution
)来操作(通常是)数组。不幸的是,这意味着通过CustomCall
实现自定义操作是一项重大投资,可能需要学习 C++和/或 GPU 编程。
Triton,由 OpenAI 构建和维护的 GPU 编译器,在 ML 编译器领域引起了轰动。Triton 提供了最佳的双赢方案:用于 GPU 核心的基于数组的编程模型。Triton 是 PyTorch 2.0 中torch.compile
的主要代码生成路径,通过 Torch Inductor 库。Triton 积极地在更高级的表示上隐藏了 GPU 编程的某些方面,以更易于访问的编程模型从 Python 中生成优化的代码。虽然 GPU 比 Triton 提供的更加灵活,但在 ML 领域,Triton 似乎对许多应用程序来说已经足够表达力。
在本文档中,我们描述了 Pallas,这是 JAX 的一个扩展,可以使用类似 Triton 的模型为 GPU 和 TPU 编写核心程序。基于 JAX 的核心语言具有几个优点:
-
虽然 Triton 向用户公开了类似 TPU 的编程模型,即在 L1-cache 的数组块上编写程序,但它足够专业以至于我们不能直接为 TPU 编译 Triton。例如,Triton 提供了专门用于处理并行写入的原子操作,这在 TPU 上并不一定有意义。一个更高级的前端可以将平台的细节抽象化,只显示基于瓦片的编程模型。这样,核心将在不同的硬件平台上可移植。
-
作为基于跟踪的数值计算的前端,JAX 既成熟又广泛使用。通过将核心编程语言嵌入到 JAX 本身中,我们可以重用 JAX 的跟踪基础设施,并提供一个类似 NumPy 的前端,这对用户来说已经很熟悉。
-
JAX 转换是其成功的关键,允许用户表达简单的程序,但通过转换实现复杂的功能。我们可以利用相同的转换(vmap、jvp 等)来转换用户编写的核心。
一个开放的问题是:JAX 真的适合作为核心语言吗?我们认为是的。Triton 表明,一个数组编程语言可以实际用于编写 GPU 核心,而 JAX 正是如此。JAX 还被证明是编译器和程序转换的灵活前端。
我们描述 Pallas 如下:首先描述我们如何扩展 JAX 以支持编写自定义核心。然后展示如何将 Pallas 降低到 Triton 和 Mosaic。最后描述通过 JAX 转换转换 Pallas 核心的现有和潜在方法。
Pallas 降低路径的可视化
Pallas:为核心扩展 JAX
我们想要强调的关键点是,Pallas 只是 JAX,附加了一些扩展:
-
用户现在在他们的 JAX 代码中使用称为
Ref
的引用类型。这使得用户在 JAX 中更加精确地控制内存访问和布局,其物理布局将更加接近。 -
用户使用 JAX 原语的子集以及一组特定于 Pallas 的原语编写他们的 JAX 程序。
-
用户通过特殊的
pallas_call
高阶函数将他们的 Pallas 核心嵌入到外部 JAX 程序中,该函数在映射中执行核心。它类似于pmap
或shard_map
,但涉及共享内存的引用。
我们将逐个通过示例讨论这三个扩展。
请注意,这些 API 仍处于实验阶段,可能会发生变化。
引用类型
让我们看一个添加两个向量的示例 Pallas 程序:
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
与常规的 JAX 程序不同,add_kernel
不接收不可变的数组参数。相反,它提供了可以使用类似 NumPy 的语法从中读取和原地更新的引用。Ref
不是 Pallas 特定的概念 - 它们被引入 JAX 来表示有状态的计算。然而,我们在编写操作可变内存的核心时可以利用它们。
Pallas 核心不仅接收与核心输入对应的Ref
,还接收作为输出的Ref
(通过pallas_call
中的out_shape
指定)。Ref
是一种特殊类型,不能直接传递给 JAX 常规的一组原语而不先读取。从Ref
中读取后,您会得到一个 JAX Array
类型,并且您必须将一个Array
写入Ref
。
从/写入 Refs
从Ref
中读取对应于将数组加载到内存层次结构的最低级别(在 GPU 上是 L1 缓存,在 TPU 上是向量寄存器)。写入Ref
类似。
def f(x_ref, o_ref):
# Using vanilla Python indexing
x = x_ref[0, 2:5, :]
# Or via Numpy advanced int indexing
o_ref[jnp.arange(3), :] = x
# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
# Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]
可以通过类似的__setitem__
样式索引来写入Ref
。
其他形式的索引(例如动态切片)可以通过pallas.load
和pallas.store
来完成,这是设计用于更轻松地从/存储到内存的新 JAX 原语。稍后我们将讨论这些新原语。
用新的 Pallas 原语扩展 JAX
因为 JAX 是以 HLO 为目标设计的,其一组原语紧密地反映了 HLO 操作的一组。针对新的编译器(例如 Triton 或 Mosaic),意味着我们可能需要用新的特定于新编译器的原语补充 JAX 的原语。同时,我们可能无法将所有 JAX 原语降低到新编译器,因此我们需要将其限制为一个子集。
因为 Pallas 最初是以 Triton 为目标设计的,我们提供了一组新的原语,目标是 Triton 编程模型。正如我们稍后将展示的,我们也可以将这些原语降低到 Mosaic。
pallas.load
和pallas.store
pallas.load
和pallas.store
是允许从内存加载和存储到内存的原语。与__getitem__
和__setitem__
不同,它们更灵活,但更冗长。具体来说,您可以使用pallas.dynamic_slice
(简称pallas.ds
)构造(可能应该上游到 JAX,以便与Ref
的__getitem__
和__setitem__
一起使用)。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
# Using integer indexing automatically broadcasts
x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
# You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)
pallas.load
和pallas.store
还支持通过掩码参数进行屏蔽。
def f(x_ref, o_ref):
# Reading from memory via pallas.load
idx = jnp.arange(8)
mask = idx < 5
x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))
当进行越界加载/存储时,屏蔽是很重要的。屏蔽的操作语义可以由编译器决定(如果我们正确理解文档的话,Triton 在掩码时避免从内存读取/写入)。
pallas.program_id
和pallas.num_programs
正如我们将很快看到的,我们将多次执行相同的 Pallas 核心(根据后端是并行还是管道)。这些新原语告诉我们“我们”在核心执行中的“位置”。
pallas.program_id
接受一个轴参数,告诉我们在多维网格的轴上,此内核当前正在执行的索引(类似于 CUDA 编程中的threadId
或jax.pmap
中的lax.axis_index
)。请注意,我们目前借用了 Triton 的“program”术语,将来可能会改为对 JAX 用户更为熟悉的术语。
def f(x_ref, o_ref):
i = pl.program_id(axis=0) # execution index in the first axis of the grid
o_ref[i] = jnp.exp(x_ref[i])
pallas.num_programs
还接受一个轴参数,并返回该轴的网格大小。
注意,虽然program_id
和num_programs
是 Triton 特有的术语,但也很容易推广到 TPU 上。
在 Pallas 中使用 JAX 原语的子集
因为我们正在编写内核,而不是高级的 HLO 程序,一些 JAX 原语可能无法高效地在我们的底层基础设施中表示。但是,我们知道我们可以支持大多数逐元素操作、简单的点积和 JAX 控制流。
虽然我们还没有完全列出我们可以在 Pallas 内核中支持的所有 JAX 原语,但我们当然可以确定一些不易降级或不太可能有用的原语:
-
conv_general
- 卷积通常不作为底层硬件的原语提供。 -
gather/scatter
- 底层编译器可能不支持非连续内存读写。
使用pallas_call
执行 Pallas 内核
现在我们已经编写了我们的 Pallas 内核(也就是带有Ref
和额外 Pallas 原语的 JAX),我们如何在 GPU 或 TPU 上执行它们呢?我们使用pallas_call
,这是一个高阶函数(类似于jax.jit
和jax.pmap
),用于执行内核。
pallas_call
的签名如下:
def pallas_call(
kernel: Callable,
in_specs: Sequence[Spec],
out_specs: Sequence[Spec],
out_shapes: Sequence[jax.ShapeDtypeStruct],
grid: Optional[Tuple[int, ...]] = None) -> Callable:
...
当我们向pallas_call
提供内核时,我们提供了额外的信息。首先是out_shape
,它告诉内核输出的形状(pallas_call
将传递一个对应的Ref
给内核以进行写入)。其余信息(in_specs
、out_specs
和grid
)是关于内核如何在加速器上调度的信息。
pallas_call
的(粗略)语义如下:
def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
def execute(*args):
outputs = map(empty_ref, out_shapes)
grid_indices = map(range, grid)
for indices in itertools.product(*grid_indices): # Could run in parallel!
local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
zip(args, in_specs)]
local_outputs = [out_spec.transform(arg, indices) for arg, out_spec in
zip(outputs, out_specs)]
kernel(*local_inputs, *local_outputs) # writes to outputs
return execute
具体来说,pallas_call
将“循环”遍历网格迭代空间,对通过in_specs
和out_specs
指定的输入和输出应用变换。在每次迭代中,内核将在变换后的输入和输出上调用。请注意,“循环”遍历迭代空间可以并行执行(例如在 GPU 上)。pallas_call
还不保证循环迭代空间的顺序,只保证会循环遍历迭代空间的每个成员。像 Triton 和 Mosaic 这样的编译器将具有与网格相关的更具体的操作语义。
变换函数
pallas_call
的in_specs
和out_specs
参数允许以某种方式转换输入和输出。Pallas 目前提供的两个选项是恒等变换(其中输入和输出保持不变)和BlockSpec
,它通过循环索引确定Ref
的固定大小切片。
BlockSpec
接受一个index_map
函数和一个block_shape
。从逻辑上讲,它接受一个数组,并沿着每个轴将其切片成block_shape
大小的块。index_map
函数接受循环索引(从网格索引集)并将其映射到块索引。转换函数将Ref
转换为对应块的Ref
的逻辑视图。当我们在block_shape
的条目中指定None
时,这对应于在内核中从该维度中“映射”掉它。
class BlockSpec:
index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
block_shape: Tuple[Optional[int], ...]
def transform(self, ref, *loop_indices):
block_indices = self.transform_function(loop_indices)
# Returns a view of `ref` starting at `block_indices` of shape self.block_shape
...
我们还可以想象其他与pallas_call
一起使用的Spec
,例如对应于重叠窗口的Spec
,以实现卷积等功能。
Pallas 作为前端的直接好处
通过为内核编写提供 JAX 前端,我们可以立即获得一些好处。
更灵活的前端
第一点是,JAX 用户已经习惯于使用 JAX 及其基于追踪的转换的好处(和局限性)。这意味着用户在编写 Pallas 内核时可以使用闭包和其他熟悉的 Python 构造。这与现有基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器不同。例如,这使得 Pallas 比 Triton 更适合模板化。
请看这个示例,演示了我们如何在 Python 中使用高阶函数来为内核模板化。
def make_kernel(eltwise_kernel):
def add(x_ref, y_ref, o_ref):
x = pl.load(x_ref, ())
y = pl.load(y_ref, ())
pl.store(o_ref, (), eltwise_kernel(x + y))
return add
kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)
pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
模拟模式
通过将内核表示为具有 JAX 原语和一些新的 Pallas 原语的程序,我们还可以直接将 Pallas 程序降级为 StableHLO 并使用 XLA 进行编译/执行。具体来说,pallas_call
可以实现为对网格的lax.scan
。这使我们能够在任何 XLA 支持的平台上(甚至是 CPU!)开发 GPU 或 TPU 内核,并使用 JAX/XLA 调试工具(如jax.debug.print
)调试它们。我们还可以使用更可靠和更好测试的 XLA 数值来验证 Triton 和 Mosaic 编译器的正确性。人们还可以想象通过扰动scan
排序来模拟 GPU 上发生的并行读写。
例子
add
我们修改我们的add_kernel
示例,使用BlockSpec
操作(2,)-大小的块。
def add_kernel(x_ref, y_ref, o_ref):
# In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=[
pl.BlockSpec(lambda i: i, (2,)),
pl.BlockSpec(lambda i: i, (2,))
],
out_specs=pl.BlockSpec(lambda i: i, (2,)),
grid=(4,))
add(x, y)
模板化的矩阵乘法
在这个示例中,我们通过对输入数组的行和列的块进行展开累加来计算输出的瓦片。我们通过高阶函数将激活函数内联到内核体中,以便我们可以发出一个融合内核。
def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
for k in range(x_ref.shape[1] // block_k):
x = x_ref[:, k*block_k:(k+1)*block_k]
y = y_ref[k*block_k:(k+1)*block_k, :]
acc += x @ y
o_ref[:, :] = activation(acc).astype(o_ref.dtype)
x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128
@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
block_m, block_n, block_k = block_shape
fused_matmul = pl.pallas_call(
partial(matmul_kernel, block_k=block_k, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
grid=(4, 4),
)
return fused_matmul(x, y)
z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)
将 Pallas 降级
用户表达其 Pallas 内核后,我们根据目标后端将其降级到不同的表示形式。在 GPU 上,我们将 Pallas 降级为 Triton IR,在 TPU 上,我们将 Pallas 降级为 Mosaic。
将 Pallas 降级到 Triton 以适配 GPU
降低 Pallas 到 Triton 是容易的,因为 Pallas 设计时就以 Triton 为目标语言。Pallas 和 Triton 主要的区别在于 Triton 没有 BlockSpec
的概念,且在内存加载和存储时使用指针而不是索引。
Triton 支持指针作为其语言中的数组元素类型,在 Triton 中可以从数组中加载和存储指针。在 Pallas 中,给定一个 (4, 5)
形状的 Ref
,x_ref
,然后执行 x_ref[3, 2]
类似操作时,我们需要将其降级为计算 x_ref
中适当行主位置的 Triton 指针(即执行 5 * 3 + 2 * 1)。类似地,当我们将切片降级到 Triton 时,例如 x_ref[4, :]
,我们需要生成一个指针数组 5 * 4 + jnp.arange(3)
。
除此之外,将 Pallas 降级到 Triton 相当直接。JAX 的点积可以降级为 Triton 的点积,JAX 的一元原语则降级为它们的 Triton 等价物。Triton 的原子操作通过新的 Pallas 原子原语降级。
将 Pallas 降级到 Mosaic 适用于 TPU
Mosaic 主要消耗标准的 MLIR 方言,并生成供 TPU 编译的 LLO。Pallas 可以通过将 JAX 原语翻译为 MLIR(主要是 vector
和 arith
方言)来降级到 Mosaic。BlockSpec
可以转换为流水线调度(即 Mosaic 中的 transform_func
)。
转换 Pallas
一个自然的问题是 JAX 变换如何与 Pallas 内核交互?主要有两种方式:Pallas 内核内部的变换和 Pallas 内核外部的变换。
Pallas 内核内部的转换实际上“应该只是工作”,只要我们能够降低变换后的代码。例如,我们可以在 JAX 内核中使用 jax.grad(jnp.sin)(...)
,因为我们可以将 cos
降低到 Triton 和 Mosaic。然而,我们可能无法将 jax.vmap(lax.dynamic_slice)
降低,因为它可能转变为我们无法降级的 gather 操作。
从外部 JAX 程序转换 Pallas 内核可能是更有趣的情况。我们如何处理像 vmap(pallas_call)
和 grad(pallas_call)
这样的事情?
vmap-of-pallas_call
vmap
自动将 JAX 程序向量化。虽然内核编写者可能希望精确控制批处理内核与非批处理变体之间的行为差异,但我们可以为 pallas_call
提供合理的默认 vmap
规则,同时提供 jax.custom_vmap
定制机制。当对 pallas_call
进行 vmap
操作时,我们会增加一个额外的网格维度,对应新的批处理维度,并转换 BlockSpec
以处理沿该维度的索引。
grad-of-pallas_call
pallas_call
的grad
使得内核的自动微分成为可能。jax.grad
可以分解为三个不同变换的应用:jvp
、partial_eval
和transpose
。原则上,在为pallas_call
实现这些规则时,我们可以重用大部分 JAX 的基础设施(因为它的行为与现有的 JAX 高阶原语类似)。
然而,内核的自动微分可能会因内存访问的转置方式而导致性能下降。如果我们编写一个具有重叠和并行读取以及不相交但并行写入的 GPU 内核,则会自动将其转置为一个具有重叠但并行写入的内核(当以原子方式执行时速度较慢),并且具有不相交但并行读取。为了生成更好地利用共享内存并行性的内核,我们需要重新排序循环并更改内核的向量化方式。不幸的是,在Pallas
中我们没有一个适合这种操作表示的程序。自动区分内核的一个潜在方向是有效地探索不同的表示形式,也许像Dex
中的表示形式那样。我们还可以看看Enzyme
如何解决这个问题。然而,对于能够有效进行转置的内核类别来说,Pallas
内核的自动微分可能仍然是有用的(例如逐元素内核)。
总的来说,jax.custom_vjp
是一种可行的逃生口,用来表达与jax.grad
一起工作的Pallas
内核。
其他转换
我们可以想象其他适用于Pallas
内核的 JAX 转换,这些转换我们尚未明确探索。例如,checkify
是一种进行功能性错误处理的 JAX 转换。我们可以想象使用checkify
与pallas_call
结合使用,以便从 GPU 内核中传递出错误代码,指示是否产生了 OOB 访问或 NaN。
另一个与之集成的潜在转换是custom_partitioning
,以便使可自动分区的内核可以与pjit
一起使用。
Pallas 快速入门
Pallas 是 JAX 的扩展,允许为 GPU 和 TPU 编写自定义核函数。Pallas 允许您使用相同的 JAX 函数和 API,但在抽象层面上操作更低。
具体来说,Pallas 要求用户考虑内存访问以及如何在硬件加速器的多个计算单元之间分割计算。在 GPU 上,Pallas 降级为 Triton,在 TPU 上,Pallas 降级为 Mosaic。
让我们深入一些例子。
注意:Pallas 仍然是一个实验性 API,可能会因更改而破坏代码!
在 Pallas 中的 hello world
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
首先,我们在 Pallas 中编写“hello world”,这是一个将两个向量相加的核函数。
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
Ref
类型
让我们稍微解析一下这个函数。与您可能编写过的大多数 JAX 函数不同,它不以 jax.Array
作为输入,也不返回任何值。相反,它以 Ref
对象作为输入。请注意,我们也没有任何输出,但我们有一个 o_ref
,它对应于所需的输出。
从 Ref
读取
在函数体中,我们首先从 x_ref
和 y_ref
中读取,用 [...]
表示(省略号表示我们正在读取整个 Ref
;或者我们也可以使用 x_ref[:]
)。像这样从 Ref
中读取返回一个 jax.Array
。
向 Ref
写入
然后我们将 x + y
写入 o_ref
。在 JAX 中历史上并不支持突变 - jax.Array
是不可变的!Ref
是新的(实验性)类型,在某些情况下允许突变。我们可以理解为向 Ref
写入是对其底层缓冲区的突变。
因此,我们编写了一个我们称之为“核函数”的程序,定义为在加速器上作为执行的原子单位运行,而不与主机进行任何交互。我们如何从 JAX 计算中调用它呢?我们使用 pallas_call
高阶函数。
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
pallas_call
将 Pallas 核函数提升为可以作为较大 JAX 程序的一部分调用的操作。但是,为了做到这一点,它需要一些额外的细节。在这里,我们指定 out_shape
,一个具有 .shape
和 .dtype
(或列表)的对象。out_shape
决定了我们在 add_vector_kernel
中的 o_ref
的形状/数据类型。
pallas_call
返回一个函数,该函数接受并返回 jax.Array
。
这里实际上发生了什么?
到目前为止,我们已经描述了如何思考 Pallas 核函数,但我们实际上所做的是编写一个函数,该函数在计算单元附近执行。
在 GPU 上,x_ref
对应于高带宽内存(HBM)中的一个值,当我们执行 x_ref[...]
时,我们将该值从 HBM 复制到静态 RAM(SRAM)中(一般情况下这是一个昂贵的操作!)。然后,我们使用 GPU 向量计算来执行加法,然后将结果值从 SRAM 复制回 HBM。
在 TPU 上,我们做了略有不同的事情。在内核被执行之前,我们从 HBM 中获取值到 SRAM 中。因此,x_ref
对应于 SRAM 中的一个值,当我们执行x_ref[...]
时,我们将该值从 SRAM 复制到寄存器中。然后,我们使用 TPU 向量计算来执行加法,然后将结果值复制回 SRAM。在内核执行完毕后,将 SRAM 中的值复制回 HBM。
我们正在编写特定后端的 Pallas 指南。即将推出!
Pallas 编程模型
在我们的“hello world”示例中,我们编写了一个非常简单的内核。它利用了我们的大小为 8 的数组可以轻松地放入硬件加速器的 SRAM 中这一事实。在大多数实际应用中,情况通常并非如此!
编写 Pallas 内核的一部分是考虑如何处理生活在高带宽内存(HBM,也称为 DRAM)中的大数组,并表达操作这些数组“块”的计算,这些块可以适应 SRAM 中。
网格
要自动“切分”输入和输出,您需要向pallas_call
提供一个grid
和BlockSpec
。
一个grid
是一组整数的元组(例如()
,(2, 3, 4)
或(8,)
),指定了一个迭代空间。例如,网格(4, 5)
将有 20 个元素:(0, 0), (0, 1), ... , (0, 4), (1, 0), ... , (3, 4)
。我们为每个元素运行一次内核函数,这是单程序多数据(SPMD)编程风格。
一个二维网格
当我们向pallas_call
提供一个grid
时,内核将执行prod(grid)
次。每次调用被称为“程序”,为了访问内核当前执行的程序(即grid
的哪个元素),我们使用program_id(axis=...)
。例如,对于调用(1, 2)
,program_id(axis=0)
返回1
,program_id(axis=1)
返回2
。
这里是一个使用grid
和program_id
的内核示例。
def iota_kernel(o_ref):
i = pl.program_id(0)
o_ref[i] = i
现在,我们使用pallas_call
来执行它,还提供了一个额外的grid
参数。
def iota(len: int):
return pl.pallas_call(iota_kernel,
out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),
grid=(len,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
在 GPU 上,每个程序在单独的线程上并行执行。因此,我们需要考虑写入 HBM 时的竞争条件。一个合理的方法是编写我们的内核,使不同的程序写入 HBM 中的不同位置,以避免这些并行写入。另一方面,通过并行化计算,我们可以快速执行诸如矩阵乘法之类的操作。
在 TPU 上,程序以并行和顺序(取决于架构)的组合方式执行,因此需要考虑略有不同。
块规格
考虑到grid
和program_id
,Pallas 提供了一种抽象,处理了许多内核中常见的索引模式。为了建立直觉,让我们尝试实现一个矩阵乘法。
在 Pallas 中实现矩阵乘法的一个简单策略是递归实现。我们知道我们的底层硬件支持小矩阵乘法(使用 GPU 和 TPU tensorcores),因此我们只需将大矩阵乘法表示为较小的矩阵乘法。
假设我们有输入矩阵 (X) 和 (Y) 并计算 (Z = XY)。我们首先将 (X) 和 (Y) 表达为块矩阵。(X) 将有“行”块,而 (Y) 将有“列”块。
[\begin{split} \begin{align} X = \end{align} \end{split}][ \begin{align} Y = \end{align} ][\begin{split} \begin{align} Z &= \end{align} \end{split}]
我们的策略是,因为 (Z) 也是一个块矩阵,我们可以将我们 Pallas 内核中的每个程序分配给一个输出块。计算每个输出块相当于在 (X) 的“行”块和 (Y) 的“列”块之间进行较小的矩阵乘法。
要表达这种模式,我们使用 BlockSpec
。BlockSpec
指定每个输入和输出的块形状,以及一个“索引映射”函数,将一组程序索引映射到一个块索引。
BlockSpec
的可视化
举个具体的例子,假设我们想要将两个 (1024, 1024)
矩阵 x
和 y
相乘得到 z
,并且希望将计算并行化为 4 个部分。我们将 z
切分为 4 个 (512, 512)
块,其中每个块使用 (512, 1024) x (1024, 512)
的矩阵乘法计算。为了表达这一点,我们首先使用一个 (2, 2)
的网格(每个程序一个块)。
对于 x
,我们使用 BlockSpec(lambda i, j: (i, 0), (512, 1024))
– 这将 x
切分成“行”块。观察程序实例 (1, 0)
和 (1, 1)
如何选择 x
中的 (1, 0)
块。对于 y
,我们使用其转置版本 BlockSpec(lambda i, j: (0, j), (1024, 512))
。最后,对于 z
,我们使用 BlockSpec(lambda i, j: (i, j), (512, 512))
。
这些 BlockSpec
通过 in_specs
和 out_specs
被传递给 pallas_call
。
在底层,pallas_call
将自动将您的输入和输出划分为每个将传递到内核的块的 Ref
。
def matmul_kernel(x_ref, y_ref, z_ref):
z_ref[...] = x_ref[...] @ y_ref[...]
def matmul(x: jax.Array, y: jax.Array):
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
)
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)
注意,这是矩阵乘法的一个非常简单的实现,但可以作为各种优化类型的起点。让我们为我们的矩阵乘法添加一个额外的特性:融合激活。这实际上非常简单!只需将一个高阶激活函数传递到内核中即可。
def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
z_ref[...] = activation(x_ref[...] @ y_ref[...])
def matmul(x: jax.Array, y: jax.Array, *, activation):
return pl.pallas_call(
partial(matmul_kernel, activation=activation),
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
grid=(2, 2),
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (x.shape[0] // 2, x.shape[1])),
pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], y.shape[1] // 2))
],
out_specs=pl.BlockSpec(
lambda i, j: (i, j), (x.shape[0] // 2, y.shape[1] // 2)
),
)(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))
最后,让我们强调 Pallas 的一个很酷的特性:它可以与 jax.vmap
组合使用!要将此矩阵乘法转换为批处理版本,我们只需将其 vmap
化。
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))
Pallas TPU
TPU 特定文档。
指南
-
使用 Pallas 编写 TPU 内核
-
什么是 TPU?
-
值得注意的特性和限制
-
支持的操作
-
-
Pipelining 和
BlockSpec
们-
TPU 及其内存空间
-
使用 VMEM/SMEM 的限制
-
入门:流水线处理
-
Pallas 中的流水线处理
-
处理减少
-
Megacore 配置中的 TPUs
-
结论
-
使用 Pallas 编写 TPU 内核
本页关注试图在 Google TPU 上运行 Pallas 内核时的重要细节。首先,TPU 后端仍处于实验阶段,并且只接受 JAX NumPy 的子集。此外,为 TPU 编写高性能代码可能需要仔细考虑硬件的本机能力。虽然许多对硬件不自然的模式将被接受,但它们最终可能需要软件模拟,并可能减慢计算速度。
警告
此功能仍应视为实验性功能,因为工作仍在进行中(特别是在改进错误消息方面)。
注意
虽然此处描述的所有功能都是实验性的,但我们仍然非常认真地维护其正确性。因此,在尝试编写 TPU 内核时可能看到“未实现”错误并不罕见。但是,如果编译器接受了内核,它必须返回预期的结果。
如果您看到意外的输出,请将其与传递interpret=True
到pallas_call
的内核运行进行比较。如果结果不一致,请提交错误报告。
什么是 TPU?
TPU 是 Google 开发的硬件加速器。您可以将 TPU 视为专门用于机器学习工作负载的 GPU。因此,它们的架构有相当大的差异。然而,我们相信 Pallas 可以使您轻松开始编写 TPU 内核,即使您没有完全理解底层硬件也是如此。话虽如此,深入了解硬件将确实使编写高性能内核变得更加容易。
简言之,TPU 与 GPU 的主要区别在于 TPU 是顺序机器,具有非常宽的向量寄存器(类似于 CPU!)。与此同时,它们允许软件安排某些操作在后台执行,使其与主指令流异步执行。这包括 HBM 内存访问(无法直接发出,而是必须通过 DMA 子单元预取到较低层次的内存层次结构)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)。
如果您对详细了解 TPU 架构感兴趣,我们建议阅读多年来发表的一系列论文集。虽然许多论文谈论特定的 TPU 代,但其中许多描述的思想也适用于后续代。
值得注意的属性和限制
BlockSpec
s 和网格迭代
在 Pallas 中,BlockSpec
s 通常按预期行为——每次核心体调用都会访问输入的片段,并且旨在初始化输出的一个片段。
警告
并非所有的窗口形状都受支持。如果你的输入的最后两个维度分别大于 8 和 128,那么这些维度中的窗口形状必须是对应因子的倍数。如果输入维度较小,则窗口应跨越整个维度。
Pallas TPU 核心的一个有趣方面是它们处理内存空间的方式:虽然pallas_call
的输入通常驻留在 HBM(主 TPU 内存)中,但传递到核心体的引用将指向内存层次结构较低的缓冲区(VMEM 或 SMEM)。这使得核心体能够以非常高的速度读写它们,而所有与 HBM 的通信(具有非常高的延迟)由编译器处理并与计算重叠。
此外,与 GPU 相比,TPU 实际上是高度序列化的机器。因此,网格通常不是并行处理的,而是按字典顺序顺序处理(尽管请参阅多核 TPU 配置部分的例外情况)。这解锁了一些有趣的功能:
-
当两个(按字典顺序)连续的网格索引使用相同输入的片段时,第二次迭代的 HBM 传输将被跳过,因为数据已经可用。
-
多个核心体调用可以向输出的同一片段写入,而不会有任何竞态条件的风险。但我们确实要求写入特定片段的所有调用是连续的。
关于输出的“连续”限制通常意味着网格维度的某些前缀总是变化,而调用需要访问的输出窗口对于其余后缀保持不变。
例如,在实现矩阵乘法的 Pallas TPU 核心时,通常会使用三维网格:前两个维度对应于沿左操作数的第一轴和第二操作数的第二轴切片。第三和最后网格轴将瓦片化减少维度。与减少维度对应的网格轴必须是最后一个,因为输出窗口沿此轴不变。输出引用随后可用作部分结果的累加器。
注意
对于这样一个低级内存层次结构(16MB+),VMEM 相当大,这使得可以使用较大的窗口大小。通常情况下,窗口大小越大,最终硬件利用率就越好。然而,可能会指定一个窗口大小,该大小(加上保存溢出矢量寄存器所需的空间)超过了 VMEM 的大小。在这种情况下,您可能会看到一个低级编译器错误消息,抱怨内存不足错误。
维度排序是有意义的
在 JAX 程序中,jax.jit
内部数组的排序通常不会影响性能,因为编译器可以自由地重新排列它们。但是,由于 Pallas 旨在暴露更低级的功能,维度顺序对生成的代码质量有很大影响。
请记住,TPU 主要在 2D 矢量寄存器上执行大部分计算。Pallas TPU 只会考虑将中间数组的最后两个维度映射到这些矢量寄存器维度(子通道和通道)。形状为(n, 1, 1)
的数组保证需要至少n
个矢量寄存器来表示。如果n
变得太大,则可能会导致溢出,并由于过大的内存占用而导致 VMEM 内存不足错误。但这也可能不会发生 — 低级编译器可以重新排列指令以降低寄存器压力,并且实际上在这方面做得非常好。尽管如此,保持最后两个维度大(特别是最后一个维度),同时使前导维度保持小是一个很好的经验法则。
多核 TPU 配置
在更新的 TPU 生成中,芯片上的两个核心通常被抽象为单个设备。为了利用多个核心,Pallas 必须打破顺序网格执行的保证,并且需要在核心上并行化一个网格轴。这是一个选择加入的过程。为了允许这样做,pallas_call
需要一个额外的名为dimension_semantics
的参数:
该参数是一个列表,其条目数量与网格中的轴数量相同。只有parallel
维度可以在核心上分区。作为一个经验法则,维度是并行的,除非输出窗口不变。因此,dimension_semantics
始终是一些parallel
轴的数字,后跟一些arbitrary
轴的数字。
尽管在 2 核 TPU 设备上分区内核通常会导致 2 倍速度提升,但实际上可能会显著小于此值。特别是如果体的不同实例具有非常不同的成本,这一点尤为真实。如果所有昂贵的步骤都映射到一个核心,而所有廉价的步骤都分配给另一个核心,则第二个核心将在第一个完成其任务之前处于空闲状态。
Pallas TPU 通常偏好将大小为 TPU 核心数量倍数的轴进行分区,并且更喜欢分区主导的网格轴。
将操作数放入 SMEM
大多数 TPU 计算将在向量单元上进行。然而,有许多情况下进行一些标量操作是有用的,例如执行控制流。因此,TPU 配备了一个单独的标量单元,并附有一个单独的标量存储器(SMEM)。按照一个经验法则,用于执行控制流决策的任何数据应放置在 SMEM 中。
SMEM 是一种低延迟内存,支持随机访问,但只能用单个指令读写 32 位值(与 VMEM 事务的 4KBi 粒度相比非常小,但由于没有对齐要求而更加灵活!)。
当实现不按规则模式访问输入块的内核时,标量内存也非常有用,例如编写块稀疏内核时。在 Pallas 中,可以通过将pallas_call
的grid
参数替换为具有非零num_scalar_prefetch
参数的PrefetchScalarGridSpec
的grid_spec
来实现这一点。如果num_scalar_prefetch
为n
,那么pallas_call
的前n
个参数将放置在 SMEM 中。对于这些参数,不应指定任何BlockSpec
。但是,对于所有后续参数的BlockSpec
,不仅会收到网格索引,还会收到领先操作数的 SMEM 引用。
注意
我们正在努力实现此功能的示例。敬请关注!
支持的数据类型
目前,Pallas TPU 仅支持以下数据类型:
-
jnp.float32
-
jnp.bfloat16
-
jnp.int*
(所有精度,除了jnp.int4
) -
jnp.uint*
(所有精度)
计算放置
所有标量(即 0D)数组将存储在标量寄存器中,并在标量核心上执行操作。所有其他操作(甚至是对单个元素但是 1D+数组的操作)将在向量核心上执行。
支持的操作
矩阵乘法
矩阵乘法始终以float32
格式生成结果。如果您的输入不是 float32,建议使用lax.dot
并将preferred_element_type
设置为jnp.float32
。
当使用lax.dot_general
时,可以将矩阵乘法操作数的最后两个维度的转置融合到操作中,这可以提高整体内核性能。
精度控制
Pallas TPU 的降低考虑到了jax.default_matmul_precision
。为了获得最佳性能(和最低精度),请使用bfloat16
。如果您关心数值精度,可能需要将精度设置为float32
。
警告
即使将 32 位操作数传递给矩阵乘法,除非请求float32
精度,否则它们将会被四舍五入为bfloat16
。
转置
如果值至少有 4 个维度,则除了最后两个轴以外的任意转置都是免费的。否则,仅实现了最后两个轴的转置。请注意,一些最后两个维度的转置可以融合到矩阵乘法中。
访问内存
可以读取或更新引用的任意片段,受实现约束的限制。目前,对于宽度为 32 位的输入没有限制,但只支持某些更窄类型的切片模式。总是支持最后两个维度中分别是 8 和 128 的倍数的对齐读写。
通常在向量内存的读写发生在形状为 (8, 128)
的瓦片上。因此,当读取或写入至少有两个维度的引用时,最佳性能是在内存访问的基础偏移具有瓦片可整除的索引,并且读取区域的大小是瓦片大小的倍数。
逐元素操作
支持许多逐元素操作。值得注意的是,硬件通常仅支持使用 32 位类型进行逐元素计算。在加载使用较低精度类型的操作数时,通常应先将其升级为 32 位类型再应用逐元素操作。
值得注意的是,它们的成本可能显著不同。因此,我们列出了三类支持的操作:廉价(🟢)、中等(🌕)和昂贵(🔴)。
操作 | 成本 |
---|---|
jnp.add ,+ |
🟢 |
jnp.sub ,- |
🟢 |
jnp.mul ,* |
🟢 |
/ ,// ,% |
🌕 |
jnp.max ,jnp.min |
🟢 |
jnp.where (选择) |
🟢 |
jnp.abs |
🟢 |
` | , ^`,`&`,`~` |
<< ,>> |
🟢 |
比较运算(== ,...) |
🟢 |
类型转换(.astype ) |
🟢 |
jnp.exp |
🌕 |
jnp.tanh |
🌕 |
jnp.pow |
🌕 |
jnp.sin |
🔴 |
jnp.cos |
🔴 |
许多 JAX 函数是基于其他 JAX 原语实现的,因此此列表可能不完整。例如,jax.nn.relu
是基于比较实现的,而 jnp.where
在 Pallas 内核中也能工作。
数组构造函数
所有常数数组构造函数都受支持(jnp.ones
,jnp.zeros
,jnp.full
)。特别是,截至今天,jax.random
模块与 Pallas 不 兼容。
归约
支持求和、最大值和最小值的归约,但一次只能在一个数组轴上进行。
对最后一个数组维度的归约通常是最慢的。对倒数第二个维度的归约更快,但仍比前面的维度慢。
广播
广播的性能特性与归约非常相似。总是支持除了最后两个维度之外的所有广播,且是免费的。沿着倒数第二个维度进行广播较慢,而沿着最后一个维度进行广播最慢。
重塑
如常地,所有维度除了最后两个维度的重塑都是支持的且是免费的。
唯一支持的情况是当重塑可以修改数组的最后两个维度时,即(1)某些前导维度展平到倒数第二个维度,或者(2)它添加了刚刚由归约移除的维度。
控制流程
目前,TPU 后端对控制流的支持有限。目前支持的函数有cond
、fori_loop
和for_loop
。然而,在编译时,循环原语会完全展开,因此请尽量保持循环执行次数合理小。
过度使用控制流可能导致低级代码生成中的显著回归,建议尽量将多个计算密集型操作挤入一个基本块中。
管道化和块规范
在本指南中,我们将介绍 TPU 中的内存空间工作原理,并展示如何在 Pallas 中编写可以将内存 I/O 与计算重叠的流水线。
#@title Imports
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
TPU 及其内存空间
TPU 和其 TensorCore 包括内存空间(用于存放数组的区域)、寄存器(临时存储标量和数组值的地方)和计算单元(用于处理寄存器中的值的计算单元)。下图显示了一个 TPU 的结构,其中 x
和 y
是存储在高带宽存储器(HBM)中的数组:
让我们更详细地讨论这个图表的组成部分:
-
内存空间:TPU 拥有高带宽内存(HBM),这通常被称为“设备内存”。还有向量内存(VMEM),一个用于存储向量和数组值的缓存,以及标量内存(SMEM),一个设计用于存储标量值的缓存。
-
寄存器:TensorCore 拥有两种主要类型的寄存器:向量寄存器(VREGs)存储数组值,标量寄存器(SREGs)存储标量值。值可以从相应的缓存(VREG 的 VMEM 和 SREG 的 SMEM)加载到内存中。
-
计算单元:TensorCore 包括标量单元、向量单元(VPU)和矩阵单元(MXU),用于进行数值计算。计算单元操作位于 SREG 和 VREG 中的值,并将输出值也存储在这些寄存器中。
为了在我们存储在 HBM 中的值 x
和 y
上执行矢量化计算,我们需要:
-
将值
x
和y
复制到 VMEM 中。 -
从 VMEM 中加载值到 VREG 中。
-
使用 VPU 或 MXU 执行计算,并将输出存储在 VREG 中。
-
将输出 VREG 中的值存储到 VMEM 中。
-
将 VMEM 中的输出值复制回 HBM。
让我们实现一个 Pallas 函数来完成这些操作!
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
# Load x and y from VMEM into VREGs
x_vregs = x_vmem_ref[:, :]
y_vregs = y_vmem_ref[:, :]
# Execute a vectorized add
z_vregs = x_vregs + y_vregs
# Store the output values in VREGs back into VMEM
z_vmem_ref[:, :] = z_vregs
def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
# pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.
# It will then copy `x` and `y` from HBM into VMEM.
z = pl.pallas_call(
add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
# pallas_call will also copy the output from VMEM back into HBM.
return z
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们编写了两个函数:add_matrices_kernel
和 add_matrices
。
add_matrices_kernel
操作使用在 VMEM 中存在的 Ref
。从 VMEM 的 Ref
加载会产生一个存在于 VREG 中的值。VREG 中的值的行为类似于 jax.Array
,我们可以在其上使用 jnp
和 jax.lax
操作来产生新的值,这些新值仍然存在于 VREG 中。当我们产生想要返回的值时,我们将它们存储在输出的 VMEM Ref
中。
add_matrices
函数作用于 jax.Array
,并返回一个 jax.Array
。在函数内部,我们将 x
和 y
传递给 pallas_call
。pallas_call
负责将 x
和 y
复制到 VMEM 中,并分配内核操作的 VMEM 缓冲区(包括分配 z_vmem_ref
,输出的 VMEM 缓冲区)。内核函数运行完成后,pallas_call
还将 z_vmem_ref
中的值复制到 HBM,最终输出一个 jax.Array
。
使用 VMEM/SMEM 的限制
Pallas 公开了对低级内存空间(如 VMEM 和 SMEM)的访问,但编写利用它们的内核需要考虑一些因素。
-
内存容量。VMEM 和 SMEM 都很小!v4 TPU 上的 VMEM 只有 16MiB,SMEM 的范围在几十到几百 KiB。如果我们的数组太大,甚至无法完全放入 VMEM 中。举个例子,一个
f32[2048, 2048]
数组就是 16MiB,因此我们上面的核心代码无法处理超过中等大小的数组。 -
内存带宽。从 HBM 和 VMEM 复制数据需要很长时间,至少与大多数计算指令相比是如此。上面的
add_matrices
函数很可能在复制 HBM 和 VMEM 之间花费的时间比执行加法本身要多。
考虑到这两个约束条件,我们必须重新思考如何提高 TPU 的性能策略。
引言:流水线
在一个行动中处理内存容量和带宽约束的流水线计算提供了一种方法。我们所说的流水线是什么意思?
目标是:并行复制到/从 HBM 和 VMEM 同时利用我们的计算单元。但在我们的程序中,这种方式相对困难,因为我们在开始进行计算之前先复制了所有的 x
和 y
,从而在复制和计算之间创建了依赖关系。
然而,如果我们可以将计算分成几个子计算(例如,当我们将两个矩阵相加时,可以将原始矩阵的“块”相加在一起),我们现在可以将其中一个子计算的复制与另一个计算的执行重叠起来。让我们通过一个简单的例子来演示:
假设我们将数组 x
和 y
分成 x1, x2
和 y1, y2
(例如,沿着主轴进行分割,每个输入结果为两个 (256, 512)
的数组)。现在我们可以执行以下流水线计算。
-
复制
x1
和y1
到 VMEM 中。 -
开始将
x2
和y2
复制到 VMEM。 -
从 VMEM 加载
x1, y1
到 VREGs 中。 -
使用计算单元执行
z1 = x1 + y1
。 -
将
z1
存储到 VMEM 中。 -
开始将
z1
从 VMEM 复制回到 HBM。 -
等待
x2, y2
被复制到 VMEM。 -
从 VMEM 加载
x2, y2
到 VREGs 中。 -
使用计算单元执行
z2 = x2 + y2
。 -
将
z2
存储到 VMEM 中。 -
等待
z1
被复制到 HBM。 -
开始将
z2
从 VMEM 复制回到 HBM。 -
等待
z2
被复制到 HBM。
在这里进行计算时,我们总是异步复制某些内容。这意味着复制过程中的一些时间并不会浪费。
决定流水线计算效率的两个最重要的因素是 a) 我们需要执行多少浮点运算(FLOPs)和 b) 我们需要复制多少字节以执行该计算。这两者的比率(FLOPs/内存使用量)称为操作的算术强度,并确定我们的流水线是计算受限还是内存受限。
Pallas 中的流水线
我们如何在 Pallas 中实现像上面那样的管道?这似乎是一系列复杂的异步数据操作和执行内核,手动实现可能会很麻烦。不要担心!Pallas 提供了一个 API 来表达管道,而不需要太多样板文件,即通过grid
和BlockSpec
。
grid
,又名循环中的内核
看看在上述流水线示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是在不同的输入上。这个泛化版本是在同一个内核上多次执行循环。pallas_call
提供了一个选项来实现这一点。
循环中的迭代次数由pallas_call
的grid
参数指定。在概念上:
pl.pallas_call(some_kernel, grid=n)(...)
映射到
for i in range(n):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
网格可以推广为多维,对应于嵌套循环。例如,
pl.pallas_call(some_kernel, grid=(n, m))(...)
等价于
for i in range(n):
for j in range(m):
# do HBM -> VMEM copies
some_kernel(...)
# do VMEM -> HBM copies
这可以推广到任意整数元组(长度为d
的网格将对应于d
个嵌套循环)。
BlockSpec
,又称如何分块输入
为了自动管道化我们的计算,我们需要向 Pallas 提供的下一部分信息是如何对其进行分块的信息。具体来说,我们需要提供一个映射,将循环的迭代映射到操作哪些输入和输出块。BlockSpec
正是这两个信息。
首先,我们为我们的输入选择一个block_shape
。在上面的流水线示例中,我们有(512, 512)
形状的数组,并沿着主维度分成两个(256, 512)
形状的数组。在这个管道中,我们的block_shape
将是(256, 512)
。
然后,我们提供一个index_map
函数,将迭代空间映射到块。具体来说,在上述管道中,第 1 次迭代我们想选择x1
,第 2 次迭代我们想使用x2
。可以用以下index_map
表达:
def x_index_map(i):
return (i, 0)
然后,我们将构建BlockSpec
:
block_spec = pl.BlockSpec(x_index_map, (256, 512))
BlockSpec
对于y
和z
与对x
的BlockSpec
将是相同的。
汇总
我们通过grid
、in_specs
和out_specs
将这些参数提供给pallas_call
(in_specs
对应于位置参数的元组,out_specs
对应于输出)。
def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,))(x, y)
add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
我们只需向原始函数添加了少量代码以添加自动管道,但BlockSpec
和grid
做了大量的重复工作!
它是如何工作的?好吧,BlockSpec
提供足够的信息来开始从 HBM 到 VMEM 预取我们输入的块。例如,如果我们开始grid
的第i
次迭代,我们可以将i + 1
传递给index_map
函数,以获取下一次迭代所需的块。然后,我们可以开始这些块的异步复制。类似地,对于输出,我们可以在开始当前迭代的输出复制之前等待上一次迭代的输出复制完成。
参数化管道
在我们的内核中,参数化块形状是常见的。当优化 Pallas 内核的性能时,块大小可能是最重要的参数!它们允许我们控制管道流程(例如,选择较小的块会在我们的流水线循环中增加更多的迭代,每个迭代的工作量较小)。
此外,我们还可以沿第二维(目前仅沿第一维进行拆分)划分输入和输出。让我们编写一个更通用的内核,处理这两个特性。
def add_matrices_pipelined_2d(
x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
m, n = x.shape
block_spec = pl.BlockSpec(lambda i, j: (i, j), (bm, bn))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(m // bm, n // bn),
)(x, y)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)
处理减少
如何使用pallas_call
实现类似jnp.sum
的功能?具体来说,我们希望在减少维度上进行流水线处理。
以将(8, 512, 512)
形状的数组减少到(512, 512)
形状为例。
x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
要使用pallas_call
实现这一点,我们可以使用大小为(8,)
的网格,并在每次迭代i
中将x[i]
加载到 VMEM 中。然后我们可以将x[i]
添加到输出 VMEM 缓冲区中。让我们先天真地实现这一点。
# Warning: this implementation is incorrect!
def naive_sum_kernel(x_ref, o_ref):
o_ref[...] += x_ref[...]
def naive_sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
naive_sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
...,
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.],
[9., 9., 9., ..., 9., 9., 9.]], dtype=float32)
注意我们如何设置BlockSpecs
:我们将(512, 512)
维度完全加载到 VMEM 中(在这里没有流水线),但在块形状的index_map
中每次迭代选择x
的第i
维度。在块形状中,我们对该维度使用None
,这表示我们正在从x
中选择一个单维度,我们希望在内核中将其挤压掉。因此,在 VMEM 中,x_ref
也是(512, 512)
形状。
out_spec
使用lambda i: (0, 0)
作为其index_map
,指示在管道过程中o_ref
保持不变。这意味着我们可以通过从中读取并向其写入来更新其值。或者可以吗?实际上有一个问题:o_ref
最初是垃圾,这意味着我们将累积到垃圾中。这将导致整体函数输出不正确的值!
因此,每当我们在内核中进行减少操作时,我们需要确保初始化存储减少值的Ref
。我们可以通过在迭代 0 时有条件地向out_ref
写入值来实现这一点。我们可以利用辅助函数pl.when
(一个方便的包装器,围绕jax.lax.cond
和pl.program_id
进行操作),查询我们在网格轴上的迭代。
def sum_kernel(x_ref, o_ref):
@pl.when(pl.program_id(axis=0) == 0)
def _():
o_ref[...] = jnp.zeros_like(o_ref)
o_ref[...] += x_ref[...]
def sum(x: jax.Array) -> jax.Array:
grid, *out_shape = x.shape
return pl.pallas_call(
sum_kernel,
grid=grid,
# None in `block_shape` means we pick a size of 1 and squeeze it away
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *out_shape))],
out_specs=pl.BlockSpec(lambda i: (0, 0), out_shape),
out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
)(x)
sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
...,
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.],
[8., 8., 8., ..., 8., 8., 8.]], dtype=float32)
此sum
函数现在输出正确的值!
关于 Pallas 中减少的最后一件事是它们必须在我们网格的最小维度(最右边)中完成(在上面的示例中,我们的网格是 1 维的,因此我们在其最小维度上进行减少)。这是因为 Pallas 生成的管道不会从 HBM 读取输出。一旦将输出值写回到 HBM,就不能重新访问它。因此,您不能在具有任何重新访问的网格维度上进行减少,因此所有减少操作都需要在最右维度上进行。
Megacore 配置的 TPU
一些 TPU 芯片有两个 TensorCores,但对 JAX 用户来说,它们表现为一个设备。这被称为“megacore”。这两个独立的 TensorCores 分别拥有自己的 VMEM、VREGs、SMEM、SREGs 和计算单元,但共享 HBM。
从概念上讲,Megacore 中的 TPU 行为类似于非常简单的 GPU,即只有两个线程。我们如何修改我们的内核以同时利用两个 TensorCores?
基本思想是,如果我们在计算中有尴尬地并行的维度,我们可以将这些维度分配到 TensorCores 上。我们可以通过向 pallas_call
提供一个称为 dimension_semantics
的注释来指示哪些维度是可并行化的。
def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
block_spec = pl.BlockSpec(lambda i: (i, 0), (256, 512))
return pl.pallas_call(
add_matrices_kernel,
out_shape=x,
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))))(
x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
...,
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.],
[2., 2., 2., ..., 2., 2., 2.]], dtype=float32)
dimension_semantics
应该是一个与 grid
长度相同的元组,其中每个条目都是"parallel"
或"arbitrary"
。"parallel"
表示对 Pallas 来说,与该维度对应的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary"
表示对 Pallas 来说,在这个网格维度上不能做任何假设,因此不能并行化。
通过指定 dimension_semantics
,我们现在可以同时在每个 TensorCore 上执行内核。Pallas 将自动处理网格的分割。
请注意,Megacore 目前仅适用于 TPU
v4
和 TPUv5p
。在其他平台上提供dimension_semantics
注释是一个空操作,但不指定它将导致只使用一个 TensorCore(即使有多个可用)。
结论
在本指南中,我们讨论了如何使用 pallas_call
、grid
和 BlockSpec
表达 TPU 管道。我们讨论了如何通过多维网格表达嵌套循环,并在减少开始时初始化累加器的情况下处理归约。我们还学习了如何通过向内核添加注释来处理 Megacore。
读者留给的练习:
-
尝试实现一个
sum
内核,该内核也可以管道化其他维度 -
还要将
add
内核和sum
内核添加到 Megacore 支持中。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· Docker 太简单,K8s 太复杂?w7panel 让容器管理更轻松!
2022-06-21 数据科学 IPython 笔记本 9.2 NumPy 简介
2020-06-21 PyTorch 1.0 中文官方教程:使用字符级别特征的RNN网络生成姓氏