JAX-中文文档-四-

JAX 中文文档(四)

原文:jax.readthedocs.io/en/latest/

理解 Jaxpr

原文:jax.readthedocs.io/en/latest/jaxpr.html

更新日期:2020 年 5 月 3 日(提交标识为 f1a46fe)。

从概念上讲,可以将 JAX 转换看作是首先对要转换的 Python 函数进行追踪特化,使其转换为一个小型且行为良好的中间形式,然后使用特定于转换的解释规则进行解释。JAX 能够在一个如此小的软件包中融合如此多的功能,其中一个原因是它从一个熟悉且灵活的编程接口(Python + NumPy)开始,并使用实际的 Python 解释器来完成大部分繁重的工作,将计算的本质提炼为一个简单的静态类型表达语言,具有有限的高阶特性。那种语言就是 jaxpr 语言。

并非所有 Python 程序都可以以这种方式处理,但事实证明,许多科学计算和机器学习程序可以。

在我们继续之前,有必要指出,并非所有的 JAX 转换都像上述描述的那样直接生成一个 jaxpr;有些转换(如微分或批处理)会在追踪期间逐步应用转换。然而,如果想要理解 JAX 内部工作原理,或者利用 JAX 追踪的结果,理解 jaxpr 是很有用的。

一个 jaxpr 实例表示一个带有一个或多个类型化参数(输入变量)和一个或多个类型化结果的函数。结果仅依赖于输入变量;没有从封闭作用域中捕获的自由变量。输入和输出具有类型,在 JAX 中表示为抽象值。代码中有两种相关的 jaxpr 表示,jax.core.Jaxprjax.core.ClosedJaxprjax.core.ClosedJaxpr 表示部分应用的 jax.core.Jaxpr,当您使用 jax.make_jaxpr() 检查 jaxpr 时获得。它具有以下字段:

  • jaxpr 是一个 jax.core.Jaxpr,表示函数的实际计算内容(如下所述)。
  • consts 是一个常量列表。

jax.core.ClosedJaxpr 最有趣的部分是实际的执行内容,使用以下语法打印为 jax.core.Jaxpr

Jaxpr ::= { lambda Var* ; Var+. let
              Eqn*
            in  [Expr+] } 

其中:

  • jaxpr 的参数显示为用 ; 分隔的两个变量列表。第一组变量是引入的用于表示已提升的常量的变量。这些称为 constvars,在 jax.core.ClosedJaxpr 中,consts 字段保存相应的值。第二组变量称为 invars,对应于跟踪的 Python 函数的输入。

  • Eqn* 是一个方程列表,定义了中间变量,这些变量指代中间表达式。每个方程将一个或多个变量定义为在某些原子表达式上应用基元的结果。每个方程仅使用输入变量和由前面的方程定义的中间变量。

  • Expr+:是 jaxpr 的输出原子表达式(文字或变量)列表。

方程式打印如下:

Eqn  ::= Var+ = Primitive [ Param* ] Expr+ 

其中:

  • Var+是要定义为基元调用的输出的一个或多个中间变量(某些基元可以返回多个值)。

  • Expr+是一个或多个原子表达式,每个表达式可以是变量或字面常量。特殊变量unitvar或字面unit,打印为*,表示在计算的其余部分中不需要的值已被省略。也就是说,单元只是占位符。

  • Param*是基元的零个或多个命名参数,打印在方括号中。每个参数显示为Name = Value

大多数 jaxpr 基元是一阶的(它们只接受一个或多个Expr作为参数):

Primitive := add | sub | sin | mul | ... 

jaxpr 基元在jax.lax模块中有文档。

例如,下面是函数func1生成的 jaxpr 示例

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...    temp = first + jnp.sin(second) * 3.
...    return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
 c:f32[8] = sin b
 d:f32[8] = mul c 3.0
 e:f32[8] = add a d
 f:f32[] = reduce_sum[axes=(0,)] e
 in (f,) } 

在这里没有 constvars,ab是输入变量,它们分别对应于firstsecond函数参数。标量文字3.0保持内联。reduce_sum基元具有命名参数axes,除了操作数e

请注意,即使执行调用 JAX 的程序构建了 jaxpr,Python 级别的控制流和 Python 级别的函数也会正常执行。这意味着仅因为 Python 程序包含函数和控制流,生成的 jaxpr 不一定包含控制流或高阶特性。

例如,当跟踪函数func3时,JAX 将内联调用inner和条件if second.shape[0] > 4,并生成与之前相同的 jaxpr

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
 c:f32[8] = sin b
 d:f32[8] = mul c 3.0
 e:f32[8] = add a d
 f:f32[] = reduce_sum[axes=(0,)] e
 in (f,) } 

处理 PyTrees

在 jaxpr 中不存在元组类型;相反,基元接受多个输入并产生多个输出。处理具有结构化输入或输出的函数时,JAX 将对其进行扁平化处理,并在 jaxpr 中它们将显示为输入和输出的列表。有关更多详细信息,请参阅 PyTrees(Pytrees)的文档。

例如,以下代码产生与前面看到的相同的 jaxpr(具有两个输入变量,每个输入元组的一个)

>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
 c:f32[8] = sin b
 d:f32[8] = mul c 3.0
 e:f32[8] = add a d
 f:f32[] = reduce_sum[axes=(0,)] e
 in (f,) } 

常量变量

jaxprs 中的某些值是常量,即它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示;非标量数组常量则提升到顶级 jaxpr,其中它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)在书面上的约定中有所不同。

高阶基元

jaxpr 包括几个高阶基元。它们更复杂,因为它们包括子 jaxprs。

条件语句

JAX 可以跟踪普通的 Python 条件语句。要捕获动态执行的条件表达式,必须使用jax.lax.switch()jax.lax.cond()构造函数,它们的签名如下:

lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B 

这两个都将在内部绑定一个名为 cond 的原始。jaxprs 中的 cond 原始反映了 lax.switch() 更一般签名的更多细节:它接受一个整数,表示要执行的分支的索引(被夹在有效索引范围内)。

例如:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
 c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
 d:i32[] = clamp 0 c 2
 e:f32[] = cond[
 branches=(
 { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
 { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
 { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
 )
 linear=(False,)
 ] d b
 in (e,) } 

cond 原始有多个参数:

  • branches 是对应于分支函数的 jaxprs。在这个例子中,这些函数分别使用一个输入变量 x
  • linear 是一个布尔值元组,由自动微分机制内部使用,用于编码在条件语句中线性使用的输入参数。

cond 原始的上述实例接受两个操作数。第一个(d)是分支索引,然后 b 是要传递给 branches 中任何 jaxpr 的操作数(arg)。

另一个例子,使用 lax.cond()

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
 b:bool[] = ge a 0.0
 c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
 d:f32[] = cond[
 branches=(
 { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
 { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
 )
 linear=(False,)
 ] c a
 in (d,) } 

在这种情况下,布尔谓词被转换为整数索引(0 或 1),branches 是对应于假和真分支的 jaxprs,按顺序排列。同样,每个函数都使用一个输入变量,分别对应于 xfalsextrue

下面的示例展示了当分支函数的输入是一个元组时,以及假分支函数包含被作为常量 hoisted 的 jnp.ones(1) 的更复杂情况

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
 e:bool[] = ge b 0.0
 f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
 g:f32[1] = cond[
 branches=(
 { lambda ; h:i32[1] i:f32[1] j:f32[]. let
 k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
 l:f32[1] = add k j
 in (l,) }
 { lambda ; m_:i32[1] n:f32[1] o:f32[]. let  in (n,) }
 )
 linear=(False, False, False)
 ] f a c d
 in (g,) } 

虽然

就像条件语句一样,Python 循环在追踪期间是内联的。如果要捕获动态执行的循环,必须使用多个特殊操作之一,jax.lax.while_loop()(一个原始)和 jax.lax.fori_loop()(一个生成 while_loop 原始的辅助程序):

lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C 

在上述签名中,“C”代表循环“carry”值的类型。例如,这里是一个 fori 循环的示例

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
 c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
 d:f32[16] = add a c
 _:i32[] _:i32[] e:f32[16] = while[
 body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
 k:i32[] = add h 1
 l:f32[16] = mul f 3.0
 m:f32[16] = add j l
 n:f32[16] = add m g
 in (k, i, n) }
 body_nconsts=2
 cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
 r:bool[] = lt o p
 in (r,) }
 cond_nconsts=0
 ] c a 0 b d
 in (e,) } 

while 原始接受 5 个参数:c a 0 b d,如下所示:

  • 0 个常量用于 cond_jaxpr(因为 cond_nconsts 为 0)
  • 两个常量用于 body_jaxprca
  • 初始携带值的 3 个参数

Scan

JAX 支持数组元素的特殊形式循环(具有静态已知形状)。由于迭代次数固定,这种形式的循环易于反向可微分。这些循环是用 jax.lax.scan() 函数构造的:

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) 

这是以 Haskell 类型签名 的形式编写的:C 是扫描携带的类型,A 是输入数组的元素类型,B 是输出数组的元素类型。

对于下面的函数 func11 的示例考虑

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  #  A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
 c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
 d:f32[] e:f32[16] = scan[
 _split_transpose=False
 jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
 j:f32[] = mul h i
 k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
 l:f32[] = add k j
 m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
 n:f32[] = add l m
 in (n, g) }
 length=16
 linear=(False, False, False, False)
 num_carry=1
 num_consts=1
 reverse=False
 unroll=1
 ] b 0.0 a c
 in (d, e) } 

linear 参数描述了每个输入变量在主体中是否保证线性使用。一旦扫描进行线性化,将有更多参数线性使用。

scan 原始接受 4 个参数:b 0.0 a c,其中:

  • 其中一个是主体的自由变量
  • 其中一个是携带的初始值
  • 接下来的两个是扫描操作的数组。

XLA_call

call 原语来源于 JIT 编译,它封装了一个子 jaxpr 和指定计算应在哪个后端和设备上运行的参数。例如

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))  
{ lambda ; a:f32[]. let
 b:f32[] = sub a 2.0
 c:f32[1] = pjit[
 name=inner
 jaxpr={ lambda ; d:f32[] e:f32[]. let
 f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
 g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
 h:f32[1] = mul g f
 i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
 j:f32[1] = add i h
 in (j,) }
 ] a b
 k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
 l:f32[1] = add k c
 in (l,) } 

XLA_pmap

如果使用 jax.pmap() 变换,要映射的函数是使用 xla_pmap 原语捕获的。考虑这个例子

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
 c:f32[1,3] = xla_pmap[
 axis_name=rows
 axis_size=1
 backend=None
 call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
 f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
 g:f32[3] = add e f
 h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
 i:f32[3] = add g h
 j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
 k:f32[3] = div i j
 in (k,) }
 devices=None
 donated_invars=(False, False)
 global_axis_size=1
 in_axes=(None, 0)
 is_explicit_global_axis_size=False
 name=inner
 out_axes=(0,)
 ] b a
 in (c,) } 

xla_pmap 原语指定了轴的名称(参数 axis_name)和要映射为 call_jaxpr 参数的函数体。此参数的值是一个具有 2 个输入变量的 Jaxpr。

参数 in_axes 指定了应该映射哪些输入变量和哪些应该广播。在我们的例子中,extra 的值被广播,arr 的值被映射。

JAX 中的外部回调

原文:jax.readthedocs.io/en/latest/notebooks/external_callbacks.html

本指南概述了各种回调函数的用途,这些函数允许 JAX 运行时在主机上执行 Python 代码,即使在jitvmapgrad或其他转换的情况下也是如此。

为什么需要回调?

回调例程是在运行时执行主机端代码的一种方式。举个简单的例子,假设您想在计算过程中打印某个变量的。使用简单的 Python print 语句,如下所示:

import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2) 
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

打印的不是运行时值,而是跟踪时的抽象值(如果您对在 JAX 中的追踪不熟悉,可以在How To Think In JAX找到一个很好的入门教程)。

要在运行时打印值,我们需要一个回调,例如jax.debug.print

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2) 
intermediate value: 3 

通过将由y表示的运行时值传递回主机进程,主机可以打印值。

回调的种类

在早期版本的 JAX 中,只有一种类型的回调可用,即jax.experimental.host_callback中实现的。host_callback例程存在一些缺陷,现已弃用,而现在推荐使用为不同情况设计的几个回调:

  • jax.pure_callback(): 适用于纯函数,即没有副作用的函数。

  • jax.experimental.io_callback(): 适用于不纯的函数,例如读取或写入磁盘数据的函数。

  • jax.debug.callback(): 适用于应反映编译器执行行为的函数。

(我们上面使用的jax.debug.print()函数是jax.debug.callback()的一个包装器)。

从用户角度来看,这三种回调的区别主要在于它们允许什么样的转换和编译器优化。

回调函数 支持返回值 jit vmap grad scan/while_loop 保证执行
jax.pure_callback ❌¹
jax.experimental.io_callback ✅/❌² ✅³
jax.debug.callback

¹ jax.pure_callback可以与custom_jvp一起使用,使其与自动微分兼容。

² 当ordered=False时,jax.experimental.io_callbackvmap兼容。

³ 注意vmapscan/while_loopio_callback具有复杂的语义,并且其行为可能在未来的版本中更改。

探索jax.pure_callback

通常情况下,jax.pure_callback是您在想要执行纯函数的主机端时应使用的回调函数:即没有副作用的函数(如打印值、从磁盘读取数据、更新全局状态等)。

您传递给jax.pure_callback的函数实际上不需要是纯的,但它将被 JAX 的转换和高阶函数假定为纯的,这意味着它可能会被静默地省略或多次调用。

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x) 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 

因为pure_callback可以省略或复制,它与jitvmap等转换以及像scanwhile_loop这样的高阶原语兼容性开箱即用:""

jax.jit(f)(x) 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 
jax.vmap(f)(x) 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 

然而,由于 JAX 无法审视回调的内容,因此pure_callback具有未定义的自动微分语义:

%xmode minimal 
Exception reporting mode: Minimal 
jax.grad(f)(x) 
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients. 

有关使用pure_callbackjax.custom_jvp的示例,请参见下文示例:pure_callbackcustom_jvp

通过设计传递给pure_callback的函数被视为没有副作用:这意味着如果函数的输出未被使用,编译器可能会完全消除回调:

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1(); 
printing something 
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2(); 

f1中,回调的输出在函数返回值中被使用,因此执行回调并且我们看到打印的输出。另一方面,在f2中,回调的输出未被使用,因此编译器注意到这一点并消除函数调用。这是对没有副作用的函数回调的正确语义。

探索jax.experimental.io_callback

jax.pure_callback()相比,jax.experimental.io_callback()明确用于与有副作用的函数一起使用,即具有副作用的函数。

例如,这是一个对全局主机端 numpy 随机生成器的回调。这是一个不纯的操作,因为在 numpy 中生成随机数的副作用是更新随机状态(请注意,这只是io_callback的玩具示例,并不一定是在 JAX 中生成随机数的推荐方式!)。

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x) 
generating float32[5] 
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32) 

io_callback默认与vmap兼容:

jax.vmap(numpy_random_like)(x) 
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[] 
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32) 

但请注意,这可能以任何顺序执行映射的回调。例如,如果在 GPU 上运行此代码,则映射输出的顺序可能会因每次运行而异。

如果保留回调的顺序很重要,可以设置ordered=True,在这种情况下,尝试vmap会引发错误:

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x) 
JaxStackTraceBeforeTransformation: ValueError: Cannot `vmap` ordered IO callback.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

ValueError: Cannot `vmap` ordered IO callback. 

另一方面,scanwhile_loop无论是否强制顺序,都与io_callback兼容:

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] 
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[] 
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32) 

pure_callback类似,如果向其传递不同的变量,io_callback在自动微分下会失败:

jax.grad(numpy_random_like)(x) 
JaxStackTraceBeforeTransformation: ValueError: IO callbacks do not support JVP.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

ValueError: IO callbacks do not support JVP. 

然而,如果回调不依赖于不同的变量,它将执行:

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0); 
hello 

pure_callback不同,在此情况下编译器不会消除回调的执行,即使回调的输出在后续计算中未使用。

探索debug.callback

pure_callbackio_callback都对调用的函数的纯度做出了一些假设,并以各种方式限制了 JAX 的变换和编译机制的操作。而debug.callback基本上不对回调函数做出任何假设,因此在程序执行过程中完全反映了 JAX 的操作。此外,debug.callback不能向程序返回任何值。

from jax import debug

def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0); 
log: 1.0 

调试回调兼容vmap

x = jnp.arange(5.0)
jax.vmap(f)(x); 
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0 

也兼容grad和其他自动微分转换。

jax.grad(f)(1.0); 
log: 1.0 

这可以使得debug.callbackpure_callbackio_callback更有用于通用调试。

示例:pure_callbackcustom_jvp

利用jax.pure_callback()的一个强大方式是将其与jax.custom_jvp结合使用(详见自定义导数规则了解更多关于custom_jvp的细节)。假设我们想要为尚未包含在jax.scipyjax.numpy包装器中的 scipy 或 numpy 函数创建一个 JAX 兼容的包装器。

在这里,我们考虑创建一个第一类贝塞尔函数的包装器,该函数实现在scipy.special.jv中。我们可以先定义一个简单的pure_callback

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorize=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) 

这使得我们可以从转换后的 JAX 代码中调用scipy.special.jv,包括使用jitvmap转换时:

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0) 
print(j1(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

这里是使用jit得到的相同结果:

print(jax.jit(j1)(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

并且这里再次是使用vmap得到的相同结果:

print(jax.vmap(j1)(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

然而,如果我们调用jax.grad,我们会看到一个错误,因为该函数没有定义自动微分规则:

jax.grad(j1)(z) 
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients. 

让我们为此定义一个自定义梯度规则。查看第一类贝塞尔函数的定义(Bessel Function of the First Kind),我们发现对于其关于参数z的导数有一个相对简单的递推关系:

[\begin{split} d J_\nu(z) = \left{ \begin{eqnarray} -J_1(z),\ &\nu=0\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}]

对于变量 (\nu) 的梯度更加复杂,但由于我们将v参数限制为整数类型,因此在这个例子中,我们不需要担心其梯度。

我们可以使用jax.custom_jvp来为我们的回调函数定义这个自动微分规则:

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz 

现在计算我们函数的梯度将会正确运行:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0)) 
-0.06447162 

此外,由于我们已经根据jv定义了我们的梯度,JAX 的架构意味着我们可以免费获得二阶及更高阶的导数:

jax.hessian(j1)(2.0) 
Array(-0.4003078, dtype=float32, weak_type=True) 

请记住,尽管这在 JAX 中完全正常运作,每次调用基于回调的jv函数都会导致将输入数据从设备传输到主机,并将scipy.special.jv的输出从主机传输回设备。当在 GPU 或 TPU 等加速器上运行时,这种数据传输和主机同步可能会导致每次调用jv时的显著开销。然而,如果您在单个 CPU 上运行 JAX(其中“主机”和“设备”位于同一硬件上),JAX 通常会以快速、零拷贝的方式执行此数据传输,使得这种模式相对直接地扩展了 JAX 的能力。

类型提升语义

原文:jax.readthedocs.io/en/latest/type_promotion.html

此文档描述了 JAX 的类型提升规则,即每对类型的 jax.numpy.promote_types() 结果。关于以下设计考虑的背景,请参阅 Design of Type Promotion Semantics for JAX

JAX 的类型提升行为通过以下类型提升格确定:

_images/type_lattice.svg

其中,例如:

  • b1 表示 np.bool_

  • i2 表示 np.int16

  • u4 表示 np.uint32

  • bf 表示 np.bfloat16

  • f2 表示 np.float16

  • c8 表示 np.complex64

  • i* 表示 Python 的 int 或弱类型的 int

  • f* 表示 Python 的 float 或弱类型的 float,以及

  • c* 表示 Python 的 complex 或弱类型的 complex

(关于弱类型的更多信息,请参阅下文的 JAX 中的弱类型值。)

任意两种类型之间的提升由它们在此格中的 join 决定,生成以下二进制提升表:

b1 u1 u2 u4 u8 i1 i2 i4 i8 bf f2 f4 f8 c8 c16 i* f* c*
b1 b1 u1 u2 u4 u8 i1 i2 i4 i8 bf f2 f4 f8 c8 c16 i* f* c*
u1 u1 u1 u2 u4 u8 i2 i2 i4 i8 bf f2 f4 f8 c8 c16 u1 f* c*
u2 u2 u2 u2 u4 u8 i4 i4 i4 i8 bf f2 f4 f8 c8 c16 u2 f* c*
u4 u4 u4 u4 u4 u8 i8 i8 i8 i8 bf f2 f4 f8 c8 c16 u4 f* c*
u8 u8 u8 u8 u8 u8 f* f* f* f* bf f2 f4 f8 c8 c16 u8 f* c*
i1 i1 i2 i4 i8 f* i1 i2 i4 i8 bf f2 f4 f8 c8 c16 i1 f* c*
i2 i2 i2 i4 i8 f* i2 i2 i4 i8 bf f2 f4 f8 c8 c16 i2 f* c*
i4 i4 i4 i4 i8 f* i4 i4 i4 i8 bf f2 f4 f8 c8 c16 i4 f* c*
i8 i8 i8 i8 i8 f* i8 i8 i8 i8 bf f2 f4 f8 c8 c16 i8 f* c*
bf bf bf bf bf bf bf bf bf bf bf f4 f4 f8 c8 c16 bf bf c8
f2 f2 f2 f2 f2 f2 f2 f2 f2 f2 f4 f2 f4 f8 c8 c16 f2 f2 c8
f4 f4 f4 f4 f4 f4 f4 f4 f4 f4 f4 f4 f4 f8 c8 c16 f4 f4 c8
f8 f8 f8 f8 f8 f8 f8 f8 f8 f8 f8 f8 f8 f8 c16 c16 f8 f8 c16
c8 c8 c8 c8 c8 c8 c8 c8 c8 c8 c8 c8 c8 c16 c8 c16 c8 c8 c8
c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16 c16
i* i* u1 u2 u4 u8 i1 i2 i4 i8 bf f2 f4 f8 c8 c16 i* f* c*
f* f* f* f* f* f* f* f* f* f* bf f2 f4 f8 c8 c16 f* f* c*
c* c* c* c* c* c* c* c* c* c* c8 c8 c8 c16 c8 c16 c* c* c*

JAX 的类型提升规则与 NumPy 的不同,如numpy.promote_types() 所示,在上述表格中以绿色背景标出的单元格中。主要有三类区别:

  • 当将弱类型值与相同类别的 JAX 类型化值进行提升时,JAX 总是偏向于 JAX 值的精度。例如,jnp.int16(1) + 1 将返回 int16 而不是像 NumPy 中那样提升为 int64。请注意,这仅适用于 Python 标量值;如果常量是 NumPy 数组,则使用上述格子结构进行类型提升。例如,jnp.int16(1) + np.array(1) 将返回 int64

  • 当将整数或布尔类型与浮点或复数类型进行提升时,JAX 总是偏向于浮点或复数类型的类型。

  • JAX 支持bfloat16非标准的 16 位浮点类型 (jax.numpy.bfloat16),这对神经网络训练非常有用。唯一显著的提升行为是对 IEEE-754 float16 的处理,其中 bfloat16 提升为 float32

NumPy 和 JAX 之间的差异是因为加速设备(如 GPU 和 TPU)在使用 64 位浮点类型时要么支付显著的性能代价(GPU),要么根本不支持 64 位浮点类型(TPU)。经典 NumPy 的提升规则过于倾向于过度提升到 64 位类型,这对设计用于加速器上运行的系统来说是个问题。

JAX 使用的浮点提升规则更适用于现代加速设备,并且在浮点类型的提升上更为谨慎。JAX 用于浮点类型的提升规则类似于 PyTorch 的规则。

Python 运算符分派的效果

请记住,Python 运算符如加号(+)会根据两个待加值的 Python 类型进行分派。这意味着,例如 np.int16(1) + 1 将按照 NumPy 的规则进行提升,而 jnp.int16(1) + 1 则按照 JAX 的规则进行提升。当两种提升类型结合使用时,可能导致令人困惑的非关联提升语义;例如 np.int16(1) + 1 + jnp.int16(1)

JAX 中的弱类型数值

在大多数情况下,JAX 中的弱类型值可以被视为具有与 Python 标量等效的提升行为,例如以下整数标量 2

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8) 

JAX 的弱类型框架旨在防止在 JAX 值与没有明确用户指定类型的值(如 Python 标量文字)之间的二进制操作中出现不需要的类型提升。例如,如果 2 不被视为弱类型,则上述表达式将导致隐式类型提升。

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32) 

在 JAX 中使用时,Python 标量有时会被提升为DeviceArray对象,例如在 JIT 编译期间。为了在这种情况下保持所需的提升语义,DeviceArray对象携带一个weak_type标志,该标志可以在数组的字符串表示中看到:

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True) 

如果显式指定了dtype,则会导致标准的强类型数组值:

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32) 
```  ## 严格的 dtype 提升

在某些情况下,禁用隐式类型提升行为并要求所有提升都是显式的可能很有用。可以通过在 JAX 中将`jax_numpy_dtype_promotion`标志设置为`'strict'`来实现。在本地,可以通过上下文管理器来完成:

```py
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y  
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard. 

为了方便起见,严格提升模式仍将允许安全的弱类型提升,因此您仍然可以编写混合使用 JAX 数组和 Python 标量的代码:

>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0 

如果您希望全局设置配置,则可以使用标准配置更新:

jax.config.update('jax_numpy_dtype_promotion', 'strict') 

要恢复默认的标准类型提升,请将此配置设置为'standard'

jax.config.update('jax_numpy_dtype_promotion', 'standard') 

Pytrees

原文:jax.readthedocs.io/en/latest/pytrees.html

什么是 pytree?

在 JAX 中,我们使用术语pytree来指代由类似容器的 Python 对象构建的类似树的结构。如果它们在 pytree 注册中,则类被视为容器类,默认包括列表、元组和字典。也就是说:

  1. 任何类型在 pytree 容器注册中的对象被视为 pytree;

  2. 任何类型在 pytree 容器注册中的对象,并且包含 pytrees,被视为 pytree。

对于 pytree 容器注册中的每个条目,注册了类似容器的类型,具有一对函数,用于指定如何将容器类型的实例转换为(children, metadata)对,以及如何将这样的对返回为容器类型的实例。使用这些函数,JAX 可以将任何已注册容器对象的树规范化为元组。

示例 pytrees:

[1, "a", object()]  # 3 leaves

(1, (2, 3), ())  # 3 leaves

[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves 

JAX 可以扩展以将其他容器类型视为 pytrees;请参见下面的扩展 pytrees。

Pytrees 和 JAX 函数

许多 JAX 函数,比如 jax.lax.scan(),操作数组的 pytrees。JAX 函数变换可以应用于接受输入和产生输出为数组 pytrees 的函数。

将可选参数应用于 pytrees

某些 JAX 函数变换接受可选参数,用于指定如何处理特定输入或输出值(例如 vmap()in_axesout_axes 参数)。这些参数也可以是 pytrees,它们的结构必须与相应参数的 pytree 结构对应。特别地,在能够“匹配”这些参数 pytrees 中的叶子与参数 pytrees 中的值的情况下,通常限制参数 pytrees 为参数 pytrees 的树前缀。

例如,如果我们将以下输入传递给 vmap()(注意函数的输入参数被视为元组):

(a1, {"k1": a2, "k2": a3}) 

我们可以使用以下 in_axes pytree 指定仅映射k2参数(axis=0),其余参数不映射(axis=None):

(None, {"k1": None, "k2": 0}) 

可选参数 pytree 结构必须与主输入 pytree 相匹配。但是,可选参数可以选择指定为“前缀” pytree,这意味着可以将单个叶值应用于整个子 pytree。例如,如果我们有与上述相同的 vmap() 输入,但希望仅映射字典参数,我们可以使用:

(None, 0)  # equivalent to (None, {"k1": 0, "k2": 0}) 

或者,如果我们希望映射每个参数,可以简单地编写一个应用于整个参数元组 pytree 的单个叶值:

0 

这恰好是vmap()的默认in_axes值!

相同的逻辑适用于指定转换函数的其他可选参数,例如 vmapout_axes

查看对象的 pytree 定义

为了调试目的查看任意对象的 pytree 定义,可以使用:

from jax.tree_util import tree_structure
print(tree_structure(object)) 

开发者信息

这主要是 JAX 内部文档,终端用户不应需要理解这一点来使用 JAX,除非在向 JAX 注册新的用户定义容器类型时。某些细节可能会更改。

内部 pytree 处理

JAX 在api.py边界(以及控制流原语中)将 pytrees 展平为叶子列表。这使得下游 JAX 内部更简单:像grad()jit()vmap()这样的转换可以处理接受并返回各种不同 Python 容器的用户函数,而系统的其他部分可以处理仅接受(多个)数组参数并始终返回扁平数组列表的函数。

当 JAX 展开 pytree 时,它将生成叶子列表和一个treedef对象,该对象编码原始值的结构。然后可以使用treedef来在转换叶子后构造匹配的结构化值。Pytrees 类似于树,而不是 DAG 或图,我们处理它们时假设具有引用透明性并且不能包含引用循环。

这里有一个简单的例子:

from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")

# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}") 
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)] 

默认情况下,pytree 容器可以是列表、元组、字典、命名元组、None、OrderedDict。其他类型的值,包括数值和 ndarray 值,都被视为叶子节点:

from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

for structured in example_containers:
  show_example(structured) 
structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0
structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None
structured=Array([0., 0.], dtype=float32)
  flat=[Array([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
  unflattened=Point(x=1.0, y=2.0) 

扩展 pytrees

默认情况下,被视为结构化值的任何部分,如果未被识别为内部 pytree 节点(即类似容器的)则被视为叶子节点:

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

show_example(Special(1., 2.)) 
structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Special(x=1.0, y=2.0) 

被视为内部 pytree 节点的 Python 类型集是可扩展的,通过全局类型注册表,注册类型的值被递归遍历。要注册新类型,可以使用register_pytree_node()

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

 Params:
 v: the value of registered type to flatten.
 Returns:
 a pair of an iterable with the children to be flattened recursively,
 and some opaque auxiliary data to pass back to the unflattening recipe.
 The auxiliary data is stored in the treedef for use during unflattening.
 The auxiliary data could be used, e.g., for dictionary keys.
 """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

 Params:
 aux_data: the opaque data that was specified during flattening of the
 current treedef.
 children: the unflattened children

 Returns:
 a re-constructed object of the registered type, using the specified
 children and auxiliary data.
 """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.)) 
structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
  unflattened=RegisteredSpecial(x=1.0, y=2.0) 

或者,您可以在您的类上定义适当的tree_flattentree_unflatten方法,并使用register_pytree_node_class()进行装饰:

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

show_example(RegisteredSpecial2(1., 2.)) 
structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0) 

在定义展开函数时,一般而言children应包含数据结构的所有动态元素(数组、动态标量和 pytrees),而aux_data应包含将被滚入treedef结构的所有静态元素。有时 JAX 需要比较treedef以确保辅助数据在扁平化过程中支持有意义的哈希和相等比较,因此必须小心处理。

操作 pytree 的所有函数都在jax.tree_util中。

自定义 PyTrees 和初始化

用户定义的 PyTree 对象的一个常见问题是,JAX 转换有时会使用意外的值初始化它们,因此初始化时进行的任何输入验证可能会失败。例如:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to MyTree 

在第一种情况下,JAX 的内部使用object()值的数组来推断树的结构;在第二种情况下,将树映射到树的函数的雅可比矩阵定义为树的树。

因此,自定义 PyTree 类的 __init____new__ 方法通常应避免进行任何数组转换或其他输入验证,或者预期并处理这些特殊情况。例如:

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a 

另一个可能性是,结构化你的 tree_unflatten 函数,避免调用 __init__;例如:

def tree_unflatten(aux_data, children):
  del aux_data  # unused in this class
  obj = object.__new__(MyTree)
  obj.a = a
  return obj 

如果你选择这条路线,请确保你的 tree_unflatten 函数在代码更新时与 __init__ 保持同步。

提前降低和编译

原文:jax.readthedocs.io/en/latest/aot.html

JAX 提供了几种转换,如jax.jitjax.pmap,返回一个编译并在加速器或 CPU 上运行的函数。正如 JIT 缩写所示,所有编译都是即时执行的。

有些情况需要进行提前(AOT)编译。当你希望在执行之前完全编译,或者希望控制编译过程的不同部分何时发生时,JAX 为您提供了一些选项。

首先,让我们回顾一下编译的阶段。假设f是由jax.jit()输出的函数/可调用对象,例如对于某个输入可调用对象Ff = jax.jit(F)。当它用参数调用时,例如f(x, y),其中xy是数组,JAX 按顺序执行以下操作:

  1. Stage out原始 Python 可调用F的特殊版本到内部表示。专门化反映了F对从参数xy的属性推断出的输入类型的限制(通常是它们的形状和元素类型)。

  2. Lower这种特殊的阶段计算到 XLA 编译器的输入语言 StableHLO。

  3. Compile降低的 HLO 程序以生成针对目标设备(CPU、GPU 或 TPU)的优化可执行文件。

  4. Execute使用数组xy作为参数执行编译后的可执行文件。

JAX 的 AOT API 允许您直接控制步骤#2、#3 和#4(但不包括#1),以及沿途的一些其他功能。例如:

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
 %c = stablehlo.constant dense<2> : tensor<i32>
 %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
 %1 = stablehlo.add %0, %arg1 : tensor<i32>
 return %1 : tensor<i32>
 }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True) 

请注意,降低的对象只能在它们被降低的同一进程中使用。有关导出用例,请参阅导出和序列化 API。

有关降低和编译函数提供的功能的更多详细信息,请参见jax.stages文档。

在上面的jax.jit的位置,您还可以lower(...)``jax.pmap()的结果,以及pjitxmap(分别来自jax.experimental.pjitjax.experimental.maps)。在每种情况下,您也可以类似地compile()结果。

所有jit的可选参数——如static_argnums——在相应的降低、编译和执行中都得到尊重。同样适用于pmappjitxmap

在上述示例中,我们可以将lower的参数替换为具有shapedtype属性的任何对象:

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32) 

更一般地说,lower只需其参数结构上提供 JAX 必须了解的内容进行专门化和降低。对于像上面的典型数组参数,这意味着shapedtype字段。相比之下,对于静态参数,JAX 需要实际的数组值(下面会详细说明)。

使用与其降低不兼容的参数调用 AOT 编译函数会引发错误:

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[] 

与此相关的是,AOT 编译函数不能通过 JAX 的即时转换(如jax.jitjax.grad()jax.vmap())进行转换。

使用静态参数进行降低

使用静态参数进行降级强调了传递给jax.jit的选项、传递给lower的参数以及调用生成的编译函数所需的参数之间的交互。继续我们上面的示例:

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
 %c = stablehlo.constant dense<14> : tensor<i32>
 %0 = stablehlo.add %c, %arg0 : tensor<i32>
 return %0 : tensor<i32>
 }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True) 

lower的结果不能直接序列化以供在不同进程中使用。有关此目的的额外 API,请参见导出和序列化。

注意,这里的lower像往常一样接受两个参数,但随后生成的编译函数仅接受剩余的非静态第二个参数。静态的第一个参数(值为 7)在降级时被视为常量,并内置到降级计算中,其中可能会与其他常量一起折叠。在这种情况下,它的乘以 2 被简化为常量 14。

尽管上面lower的第二个参数可以被一个空的形状/数据类型结构替换,但静态的第一个参数必须是一个具体的值。否则,降级将会出错:

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32) 

AOT 编译的函数不能被转换

编译函数专门针对一组特定的参数“类型”,例如我们正在运行的示例中具有特定形状和元素类型的数组。从 JAX 的内部角度来看,诸如jax.vmap()之类的转换会以一种方式改变函数的类型签名,使得已编译的类型签名失效。作为一项政策,JAX 简单地禁止已编译的函数参与转换。示例:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
 [13., 17., 21.],
 [25., 29., 33.],
 [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'> 

g_aot参与自动微分(例如jax.grad())时也会引发类似的错误。为了一致性,jax.jit的转换也被禁止,尽管jit并没有实质性地修改其参数的类型签名。

调试信息和分析,在可用时

除了主要的 AOT 功能(分离和显式的降级、编译和执行),JAX 的各种 AOT 阶段还提供一些额外的功能,以帮助调试和收集编译器反馈。

例如,正如上面的初始示例所示,降级函数通常提供文本表示。编译函数也是如此,并且还提供来自编译器的成本和内存分析。所有这些都通过jax.stages.Loweredjax.stages.Compiled对象上的方法提供(例如,上面的lowered.as_text()compiled.cost_analysis())。

这些方法旨在帮助手动检查和调试,而不是作为可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这导致了两个重要的注意事项:

  1. 如果某些功能在 JAX 当前的后端上不可用,则其方法将返回某些微不足道的东西(类似于False)。例如,如果支持 JAX 的编译器不提供成本分析,则compiled.cost_analysis()将为None

  2. 如果某些功能可用,则对应方法提供的内容仍然有非常有限的保证。返回值在 JAX 的配置、后端/平台、版本或甚至方法的调用之间,在类型、结构或值上不需要保持一致。JAX 无法保证 compiled.cost_analysis() 在一天的输出将会在随后的一天保持相同。

如果有疑问,请参阅 jax.stages 的包 API 文档。

检查暂停的计算

此笔记顶部列表中的第一个阶段提到专业化和分阶段,之后是降低。JAX 内部对其参数类型专门化的函数的概念,并非始终在内存中具体化为数据结构。要显式构建 JAX 在内部Jaxpr 中间语言中函数专门化的视图,请参见 jax.make_jaxpr()

导出和序列化

jax.readthedocs.io/en/latest/export/index.html

指南

  • 导出和序列化分阶段计算

    • 支持逆向模式自动微分(AD)

    • 兼容性保证

    • 跨平台和多平台导出

    • 形状多态导出

    • 设备多态导出

    • 调用约定版本

    • 从 jax.experimental.export 的迁移指南

  • 形状多态性

    • 形状多态性的正确性

    • 使用维度变量进行计算

  • 与 TensorFlow 的互操作性

导出和序列化分离计算

原文:jax.readthedocs.io/en/latest/export/export.html

提前降级和编译的 API 生成的对象可用于调试或在同一进程中进行编译和执行。有时候,您希望将降级后的 JAX 函数序列化,以便在稍后的时间在单独的进程中进行编译和执行。这将允许您:

  • 在另一个进程或机器上编译并执行该函数,而无需访问 JAX 程序,并且无需重复分离和降低级别,例如在推断系统中。

  • 跟踪和降低一个在没有访问您希望稍后编译和执行该函数的加速器的机器上的函数。

  • 存档 JAX 函数的快照,例如以便稍后能够重现您的结果。注意:请查看此用例的兼容性保证。

这里有一个例子:

>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x

>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'

>>> exported.in_avals
(ShapedArray(float32[]),)

>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
 func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)

>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)

>>> callee(1.)
Array(96., dtype=float32) 

序列化分为两个阶段:

  1. 导出以生成一个包含降级函数的 StableHLO 和调用它所需的元数据的 jax.export.Exported 对象。我们计划添加代码以从 TensorFlow 生成 Exported 对象,并使用来自 TensorFlow 和 PyTorch 的 Exported 对象。

  2. 使用 flatbuffers 格式的字节数组进行实际序列化。有关与 TensorFlow 的交互操作的替代序列化,请参阅与 TensorFlow 的互操作性。

支持反向模式 AD

序列化可以选择支持高阶反向模式 AD。这是通过将原始函数的 jax.vjp() 与原始函数一起序列化,直到用户指定的顺序(默认为 0,意味着重新水化的函数无法区分)完成的:

>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1³
Array(0.007, dtype=float32)

>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1²
Array(0.21000001, dtype=float32)

>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)

>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)

>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)  
Traceback (most recent call last):
ValueError: No VJP is available 

请注意,在序列化时计算 VJP 函数是惰性的,当 JAX 程序仍然可用时。这意味着它遵守 JAX VJP 的所有特性,例如 jax.custom_vjp()jax.remat()

请注意,重新水化的函数不支持任何其他转换,例如前向模式 AD(jvp)或 jax.vmap()

兼容性保证

您不应仅从降低中获取的原始 StableHLO(jax.jit(f).lower(1.).compiler_ir())用于归档和在另一个进程中进行编译,有几个原因。

首先,编译可能使用不同版本的编译器,支持不同版本的 StableHLO。jax.export 模块通过使用 StableHLO 的 可移植工件特性 处理此问题。

自定义调用的兼容性保证

其次,原始的 StableHLO 可能包含引用 C++ 函数的自定义调用。JAX 用于降低少量基元的自定义调用,例如线性代数基元、分片注释或 Pallas 核心。这些不在 StableHLO 的兼容性保证范围内。这些函数的 C++ 实现很少更改,但确实会更改。

jax.export 提供以下导出兼容性保证:JAX 导出的工件可以由编译器和 JAX 运行时系统编译和执行,条件是它们:

  • 比用于导出的 JAX 版本新的长达 6 个月(我们称 JAX 导出提供6 个月的向后兼容性)。如果要归档导出的工件以便稍后编译和执行,这很有用。

  • 比用于导出的 JAX 版本旧的长达 3 周(我们称 JAX 导出提供3 周的向前兼容性)。如果要使用已在导出完成时已部署的消费者编译和运行导出的工件,例如已部署的推断系统。

(特定的兼容性窗口长度与 JAX 对于 jax2tf 所承诺的相同,并基于TensorFlow 的兼容性。术语“向后兼容性”是从消费者的角度,例如推断系统。)

重要的是导出和消费组件的构建时间,而不是导出和编译发生的时间。对于外部 JAX 用户来说,可以在不同版本的 JAX 和 jaxlib 上运行;重要的是 jaxlib 发布的构建时间。

为减少不兼容的可能性,内部 JAX 用户应该:

  • 尽可能频繁地重建和重新部署消费系统

外部用户应该:

  • 尽可能以相同版本的 jaxlib 运行导出和消费系统,并

  • 用最新发布版本的 jaxlib 进行归档导出。

如果绕过 jax.export API 获取 StableHLO 代码,则不适用兼容性保证。

只有部分自定义调用被保证稳定,并具有兼容性保证(参见列表)。我们会持续向允许列表中添加更多自定义调用目标,同时进行向后兼容性测试。如果尝试序列化调用其他自定义调用目标的代码,则在导出期间会收到错误。

如果您希望禁用特定自定义调用的此安全检查,例如目标为 my_target,您可以将 export.DisabledSafetyCheck.custom_call("my_target") 添加到 export 方法的 disabled_checks 参数中,如以下示例所示:

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src import core
>>> from jax._src.interpreters import mlir
>>> # Define a new primitive backed by a custom call
>>> new_prim = core.Primitive("new_prim")
>>> _ = new_prim.def_abstract_eval(lambda x: x)
>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
 %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
 return %0 : tensor<f32>
 }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(new_prim.bind))(1.)  
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_bind

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(new_prim.bind),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_prim")])(1.) 

跨平台和多平台导出

JAX 降级对于少数 JAX 原语是平台特定的。默认情况下,代码将为导出机器上的加速器进行降级和导出:

>>> from jax import export
>>> export.default_export_platform()
'cpu' 

存在一个安全检查,当尝试在没有为其导出代码的加速器的机器上编译 Exported 对象时会引发错误。

您可以明确指定代码应导出到哪些平台。这使您能够在导出时指定不同于您当前可用的加速器,甚至允许您指定多平台导出以获取一个可以在多个平台上编译和执行的Exported对象。

>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the export platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)  
Traceback (most recent call last):
ValueError: Function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...    lowering_platforms=['tpu'],
...    disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

# and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
...    lowering_platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True) 

对于多平台导出,StableHLO 将包含多个降级版本,但仅针对那些需要的原语,因此生成的模块大小应该只比具有默认导出的模块稍大一点。作为极端情况,当序列化一个没有任何需要平台特定降级的原语的模块时,您将获得与单平台导出相同的 StableHLO。

>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)  
9220

>>> exp_multi = export.export(jax.jit(f),
...                           lowering_platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)  
9282 

形状多态导出

当在即时编译(JIT)模式下使用时,JAX 将为每个输入形状的组合单独跟踪和降低函数。在导出时,有时可以对某些输入维度使用维度变量,以获取一个可以与多种输入形状组合一起使用的导出物件。

请参阅形状多态文档。

设备多态导出

导出的物件可能包含用于输入、输出和一些中间结果的分片注释,但这些注释不直接引用在导出时存在的实际物理设备。相反,分片注释引用逻辑设备。这意味着您可以在不同于导出时使用的物理设备上编译和运行导出的物件。

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> # Use the first 4 devices for exporting.
>>> export_devices = jax.local_devices()[:4]
>>> export_mesh = Mesh(export_devices, ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(8 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(export_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_3 index=(slice(8, 16, None),)'] 

尝试使用与导出时不同数量的设备调用导出物件是错误的:

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> exp.call(arg)  
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device. 

有助于为使用新网格调用导出物件分片输入的辅助函数:

>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg) 

作为特殊功能,如果一个函数为 1 个设备导出,并且不包含分片注释,则可以在具有相同形状但在多个设备上分片的参数上调用它,并且编译器将适当地分片函数:

```python

>>> import jax

>>> from jax import export

>>> from jax.sharding import Mesh, NamedSharding

>>> from jax.sharding import PartitionSpec as P

>>> def f(x):

...   return jnp.cos(x)

>>> arg = jnp.arange(4)

>>> exp = export.export(jax.jit(f))(arg)

>>> exp.in_avals

(ShapedArray(int32[4]),)

>>> exp.nr_devices

1

>>> # 准备用于调用 `exp` 的网格。

>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.

>>> sharded_arg = jax.device_put(arg,

...                              NamedSharding(calling_mesh, P("b")))

>>> res = exp.call(sharded_arg)

```py

## Calling convention versions

The JAX export support has evolved over time, e.g., to support effects. In order to support compatibility (see compatibility guarantees) we maintain a calling convention version for each `Exported`. As of June 2024, all function exported with version 9 (the latest, see all calling convention versions):

from jax import export

exp: export.Exported = export.export(jnp.cos)(1.)

exp.calling_convention_version

9


At any given time, the export APIs may support a range of calling convention versions. You can control which calling convention version to use using the `--jax-export-calling-convention-version` flag or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable:

from jax import export

(export.minimum_supported_calling_convention_version, export.maximum_supported_calling_convention_version)

(9, 9)

from jax._src import config

with config.jax_export_calling_convention_version(9):

... exp = export.export(jnp.cos)(1.)

... exp.calling_convention_version

9


We reserve the right to remove support for generating or consuming calling convention versions older than 6 months.

### Module calling convention

The `Exported.mlir_module` has a `main` function that takes an optional first platform index argument if the module supports multiple platforms (`len(platforms) > 1`), followed by the token arguments corresponding to the ordered effects, followed by the kept array arguments (corresponding to `module_kept_var_idx` and `in_avals`). The platform index is a i32 or i64 scalar encoding the index of the current compilation platform into the `platforms` sequence.

Inner functions use a different calling convention: an optional platform index argument, optional dimension variable arguments (scalar tensors of type i32 or i64), followed by optional token arguments (in presence of ordered effects), followed by the regular array arguments. The dimension arguments correspond to the dimension variables appearing in the `args_avals`, in sorted order of their names.

Consider the lowering of a function with one array argument of type `f32[w, 2 * h]`, where `w` and `h` are two dimension variables. Assume that we use multi-platform lowering, and we have one ordered effect. The `main` function will be as follows:

func public main(

        platform_index: i32 {jax.global_constant="_platform_index"},

        token_in: token,

        arg: f32[?, ?]) {

    arg_w = hlo.get_dimension_size(arg, 0)

    dim1 = hlo.get_dimension_size(arg, 1)

    arg_h = hlo.floordiv(dim1, 2)

    call _check_shape_assertions(arg)  # See below

    token = new_token()

    token_out, res = call _wrapped_jax_export_main(platform_index,

                                                    arg_h,

                                                    arg_w,

                                                    token_in,

                                                    arg)

    return token_out, res

}

The actual computation is in `_wrapped_jax_export_main`, taking also the values of `h` and `w` dimension variables.

The signature of the `_wrapped_jax_export_main` is:

func private _wrapped_jax_export_main(

    platform_index: i32 {jax.global_constant="_platform_index"},

    arg_h: i32 {jax.global_constant="h"},

    arg_w: i32 {jax.global_constant="w"},

    arg_token: stablehlo.token {jax.token=True},

    arg: f32[?, ?]) -> (stablehlo.token, ...)

Prior to calling convention version 9 the calling convention for effects was different: the `main` function does not take or return a token. Instead the function creates dummy tokens of type `i1[0]` and passes them to the `_wrapped_jax_export_main`. The `_wrapped_jax_export_main` takes dummy tokens of type `i1[0]` and will create internally real tokens to pass to the inner functions. The inner functions use real tokens (both before and after calling convention version 9)

Also starting with calling convention version 9, function arguments that contain the platform index or the dimension variable values have a `jax.global_constant` string attribute whose value is the name of the global constant, either `_platform_index` or a dimension variable name. The global constant name may be empty if it is not known. Some global constant computations use inner functions, e.g., for `floor_divide`. The arguments of such functions have a `jax.global_constant` attribute for all attributes, meaning that the result of the function is also a global constant.

Note that `main` contains a call to `_check_shape_assertions`. JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` have values >= 1\. We must check these constraints when we invoke the module. We use a special custom call `@shape_assertion` that takes a boolean first operand, a string `error_message` attribute that may contain format specifiers `{0}`, `{1}`, …, and a variadic number of integer scalar operands corresponding to the format specifiers.

func private _check_shape_assertions(arg: f32[?, ?]) {

    # Check that w is >= 1

    arg_w = hlo.get_dimension_size(arg, 0)

    custom_call @shape_assertion(arg_w >= 1, arg_w,

        error_message="Dimension variable 'w' must have integer value >= 1\. Found {0}")

    # Check that dim1 is even

    dim1 = hlo.get_dimension_size(arg, 1)

    custom_call @shape_assertion(dim1 % 2 == 0, dim1,

        error_message="Dimension variable 'h' must have integer value >= 1\. Found non-zero remainder {0}")

    # Check that h >= 1

    arg_h = hlo.floordiv(dim1, 2)

    custom_call @shape_assertion(arg_h >= 1, arg_h,

        error_message=""Dimension variable 'h' must have integer value >= 1\. Found {0}")

### Calling convention versions

We list here a history of the calling convention version numbers:

+   Version 1 used MHLO & CHLO to serialize the code, not supported anymore.

+   Version 2 supports StableHLO & CHLO. Used from October 2022\. Not supported anymore.

+   Version 3 supports platform checking and multiple platforms. Used from February 2023\. Not supported anymore.

+   Version 4 supports StableHLO with compatibility guarantees. This is the earliest version at the time of the JAX native serialization launch. Used in JAX from March 15, 2023 (cl/516885716). Starting with March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493). The support for this version was dropped on October 17th, 2023 (cl/573858283).

+   Version 5 adds support for `call_tf_graph`. This is currently used for some specialized use cases. Used in JAX from May 3rd, 2023 (cl/529106145).

+   第 6 版添加了对 `disabled_checks` 属性的支持。此版本要求 `platforms` 属性不为空。自 2023 年 6 月 7 日由 XlaCallModule 支持,自 2023 年 6 月 13 日(JAX 0.4.13)起支持 JAX。

+   第 7 版增加了对 `stablehlo.shape_assertion` 操作和在 `disabled_checks` 中指定的 `shape_assertions` 的支持。参见[形状多态性存在错误](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism)。自 2023 年 7 月 12 日(cl/547482522)由 XlaCallModule 支持,自 2023 年 7 月 20 日(JAX 0.4.14)起支持 JAX 序列化,并自 2023 年 8 月 12 日(JAX 0.4.15)起成为默认选项。

+   第 8 版添加了对 `jax.uses_shape_polymorphism` 模块属性的支持,并仅在该属性存在时启用形状细化传递。自 2023 年 7 月 21 日(cl/549973693)由 XlaCallModule 支持,自 2023 年 7 月 26 日(JAX 0.4.14)起支持 JAX,并自 2023 年 10 月 21 日(JAX 0.4.20)起成为默认选项。

+   第 9 版添加了对 effects 的支持。详见 `export.Exported` 的文档字符串获取准确的调用约定。在此调用约定版本中,我们还使用 `jax.global_constant` 属性标记平台索引和维度变量参数。自 2023 年 10 月 27 日由 XlaCallModule 支持,自 2023 年 10 月 20 日(JAX 0.4.20)起支持 JAX,并自 2024 年 2 月 1 日(JAX 0.4.24)起成为默认选项。截至 2024 年 3 月 27 日,这是唯一支持的版本。

## 从 `jax.experimental.export` 迁移指南。

在 2024 年 6 月 14 日,我们废弃了 `jax.experimental.export` API,采用了 `jax.export` API。有一些小改动:

+   `jax.experimental.export.export`:

    +   旧函数允许任何 Python 可调用对象或 `jax.jit` 的结果。现在仅接受后者。在调用 `export` 前必须手动应用 `jax.jit` 到要导出的函数。

    +   旧的 `lowering_parameters` 关键字参数现在命名为 `platforms`。

+   `jax.experimental.export.default_lowering_platform()` 现在是 `jax.export.default_export_platform()`。

+   `jax.experimental.export.call` 现在是 `jax.export.Exported` 对象的一个方法。不再使用 `export.call(exp)`,应使用 `exp.call`。

+   `jax.experimental.export.serialize` 现在是 `jax.export.Exported` 对象的一个方法。不再使用 `export.serialize(exp)`,应使用 `exp.serialize()`。

+   配置标志 `--jax-serialization-version` 已弃用。使用 `--jax-export-calling-convention-version`。

+   `jax.experimental.export.minimum_supported_serialization_version` 的值现在在 `jax.export.minimum_supported_calling_convention_version`。

+   `jax.export.Exported` 的以下字段已重命名。

    +   `uses_shape_polymorphism` 现在是 `uses_global_constants`。

    +   `mlir_module_serialization_version` 现在是 `calling_convention_version`。

    +   `lowering_platforms` 现在是 `platforms`。
posted @ 2024-06-21 14:07  绝不原创的飞龙  阅读(21)  评论(0编辑  收藏  举报