JAX-中文文档-二-

JAX 中文文档(二)

原文:jax.readthedocs.io/en/latest/

JAX 教程

原文:jax.readthedocs.io/en/latest/tutorials.html

  • 快速入门
  • 关键概念

  • 即时编译

  • 自动向量化

  • 自动微分

  • 调试入门

  • 伪随机数

  • 使用 pytrees 工作

  • 分片计算入门

  • 有状态计算

关键概念

原文:jax.readthedocs.io/en/latest/key-concepts.html

本节简要介绍了 JAX 包的一些关键概念。

JAX 数组 (jax.Array)

JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它也有一些重要的区别。

数组创建

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()jax.numpy.linspace()jax.numpy.arange() 等。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array) 
True 

如果您在代码中使用 Python 类型注解,jax.Array 是 jax 数组对象的适当注释(参见 jax.typing 以获取更多讨论)。

数组设备和分片

JAX 数组对象具有一个 devices 方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:

x.devices() 
{CpuDevice(id=0)} 

一般来说,数组可能会在多个设备上 分片,您可以通过 sharding 属性进行检查:

x.sharding 
SingleDeviceSharding(device=CpuDevice(id=0)) 

在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换

除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括

  • jax.jit(): 即时(JIT)编译;参见即时编译

  • jax.vmap(): 向量化变换;参见自动向量化

  • jax.grad(): 梯度变换;参见自动微分

以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0)) 
1.05 

通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

jit()vmap()grad() 等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪

变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。

您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x) 
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)> 

打印的值不是数组 x,而是代表 x 的关键属性的 Tracer 实例,比如它的 shapedtype。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()vmap()grad() 可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs

JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。

例如,考虑我们上面定义的 selu 函数:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

我们可以使用 jax.make_jaxpr() 实用程序来将该函数转换为一个 jaxpr,给定特定的输入:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x) 
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) } 

与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees

JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。

以下是一些可以作为 pytrees 处理的对象的示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)] 
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5] 
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0] 

JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map() 可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce() 可以用于在树中的叶子上应用约简操作。

你可以在《使用 pytrees 教程》中了解更多信息。

即时编译

原文:jax.readthedocs.io/en/latest/jit-compilation.html

在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit() 变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。

如何工作 JAX 变换

在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。

查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0)) 
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。

重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x) 的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数副作用 是陌生的术语,这在 🔪 JAX - The Sharp Bits 🔪: Pure Functions 中有稍微详细的解释。

非纯函数很危险,因为在 JAX 变换下它们可能无法按预期运行;它们可能会悄无声息地失败,或者产生意外的下游错误,如泄漏的跟踪器。此外,JAX 通常无法检测到是否存在副作用。(如果需要调试打印,请使用 jax.debug.print()。要表达一般性副作用而牺牲性能,请参阅 jax.experimental.io_callback()。要检查跟踪器泄漏而牺牲性能,请使用 jax.check_tracer_leaks())。

在跟踪时,JAX 通过 跟踪器 对象包装每个参数。这些跟踪器记录了在函数调用期间(即在常规 Python 中发生)对它们执行的所有 JAX 操作。然后,JAX 使用跟踪器记录重构整个函数。重构的输出是 jaxpr。由于跟踪器不记录 Python 的副作用,它们不会出现在 jaxpr 中。但是,副作用仍会在跟踪过程中发生。

注意:Python 的 print() 函数不是纯函数:文本输出是函数的副作用。因此,在跟踪期间,任何 print() 调用都将只发生一次,并且不会出现在 jaxpr 中:

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.)) 
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

看看打印出来的 x 是一个 Traced 对象?这就是 JAX 内部的工作原理。

Python 代码至少运行一次的事实严格来说是一个实现细节,因此不应依赖它。然而,在调试时理解它是有用的,因为您可以在计算的中间值打印出来。

一个关键的理解点是,jaxpr 捕捉函数在给定参数上执行的方式。例如,如果我们有一个 Python 条件语句,jaxpr 只会了解我们选择的分支:

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3]))) 
{ lambda ; a:i32[3]. let  in (a,) } 

JIT 编译函数

正如之前所解释的,JAX 使得操作能够使用相同的代码在 CPU/GPU/TPU 上执行。让我们看一个计算缩放指数线性单元SELU)的例子,这是深度学习中常用的操作:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready() 
2.81 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 

上述代码一次只发送一个操作到加速器。这限制了 XLA 编译器优化我们函数的能力。

自然地,我们希望尽可能多地向 XLA 编译器提供代码,以便它能够完全优化它。为此,JAX 提供了jax.jit()转换,它将即时编译一个与 JAX 兼容的函数。下面的示例展示了如何使用 JIT 加速前述函数。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready() 
1.01 ms ± 2.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

刚刚发生了什么事:

  1. 我们定义了selu_jit作为selu的编译版本。

  2. 我们在x上调用了selu_jit一次。这是 JAX 进行其追踪的地方 - 它需要一些输入来包装成追踪器。然后,jaxpr 使用 XLA 编译成非常高效的代码,针对您的 GPU 或 TPU 进行优化。最后,编译的代码被执行以满足调用。后续对selu_jit的调用将直接使用编译后的代码,跳过 Python 实现。(如果我们没有单独包括预热调用,一切仍将正常运行,但编译时间将包含在基准测试中。因为我们在基准测试中运行多个循环,所以仍会更快,但这不是公平的比较。)

  3. 我们计时了编译版本的执行速度。(注意使用block_until_ready(),这是由于 JAX 的异步调度所需。)

为什么我们不能把所有东西都即时编译(JIT)呢?

在上面的例子中,你可能会想知道我们是否应该简单地对每个函数应用jax.jit()。要理解为什么不是这样,并且何时需要/不需要应用jit,让我们首先检查一些jit不适用的情况。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_1169/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function g at /tmp/ipykernel_1169/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 

在这两种情况下的问题是,我们尝试使用运行时值来条件追踪时间流程。在 JIT 中追踪的值,例如这里的xn,只能通过它们的静态属性(如shapedtype)影响控制流,而不能通过它们的值。有关 Python 控制流与 JAX 交互的更多详细信息,请参见🔪 JAX - The Sharp Bits 🔪: Control Flow

处理这个问题的一种方法是重写代码,避免在值条件上使用条件语句。另一种方法是使用特殊的控制流操作符,例如jax.lax.cond()。然而,有时这并不可行或实际。在这种情况下,可以考虑只对函数的部分进行 JIT 编译。例如,如果函数中最消耗计算资源的部分在循环内部,我们可以只对内部的那部分进行 JIT 编译(但务必查看关于缓存的下一节,以避免出现问题):

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20) 
Array(30, dtype=int32, weak_type=True) 

将参数标记为静态的

如果我们确实需要对具有输入值条件的函数进行 JIT 编译,我们可以告诉 JAX 通过指定static_argnumsstatic_argnames来帮助自己获取特定输入的较少抽象的追踪器。这样做的成本是生成的 jaxpr 和编译的工件依赖于传递的特定值,因此 JAX 将不得不针对指定静态输入的每个新值重新编译函数。只有在函数保证看到有限的静态值集时,这才是一个好策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10)) 
10 
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20)) 
30 

当使用jit作为装饰器时,要指定这些参数的一种常见模式是使用 Python 的functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20)) 
30 

JIT 和缓存

通过第一次 JIT 调用的编译开销,了解jax.jit()如何以及何时缓存先前的编译是有效使用它的关键。

假设我们定义f = jax.jit(g)。当我们首次调用f时,它会被编译,并且生成的 XLA 代码将被缓存。后续调用f将重用缓存的代码。这就是jax.jit如何弥补编译的前期成本。

如果我们指定了static_argnums,那么缓存的代码将仅在标记为静态的参数值相同时使用。如果它们中任何一个发生更改,将重新编译。如果存在许多值,则您的程序可能会花费更多时间进行编译,而不是逐个执行操作。

避免在循环或其他 Python 作用域内定义的临时函数上调用jax.jit()。对于大多数情况,JAX 能够在后续调用jax.jit()时使用编译和缓存的函数。然而,由于缓存依赖于函数的哈希值,在重新定义等价函数时会引发问题。这将导致每次在循环中不必要地重新编译:

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready() 
jit called in a loop with partials:
217 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
219 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.33 ms ± 29.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 

自动向量化

原文:jax.readthedocs.io/en/latest/automatic-vectorization.html

在前一节中,我们讨论了通过jax.jit()函数进行的 JIT 编译。本文档还讨论了 JAX 的另一个转换:通过jax.vmap()进行向量化。

手动向量化

考虑以下简单代码,计算两个一维向量的卷积:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w) 
Array([11., 20., 29.], dtype=float32) 

假设我们希望将此函数应用于一批权重w到一批向量x

xs = jnp.stack([x, x])
ws = jnp.stack([w, w]) 

最简单的选择是在 Python 中简单地循环遍历批处理:

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

这会产生正确的结果,但效率不高。

为了有效地批处理计算,通常需要手动重写函数,以确保它以向量化形式完成。这并不难实现,但涉及更改函数处理索引、轴和输入其他部分的方式。

例如,我们可以手动重写convolve(),以支持跨批处理维度的向量化计算,如下所示:

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

随着函数复杂性的增加,这种重新实现可能会变得混乱且容易出错;幸运的是,JAX 提供了另一种方法。

自动向量化

在 JAX 中,jax.vmap()转换旨在自动生成这样的函数的向量化实现:

auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

它通过类似于jax.jit()的追踪函数来实现这一点,并在每个输入的开头自动添加批处理轴。

如果批处理维度不是第一维,则可以使用in_axesout_axes参数来指定输入和输出中批处理维度的位置。如果所有输入和输出的批处理轴相同,则可以使用整数,否则可以使用列表。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst) 
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32) 

jax.vmap()还支持只有一个参数被批处理的情况:例如,如果您希望将一组单一的权重w与一批向量x进行卷积;在这种情况下,in_axes参数可以设置为None

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

结合转换

与所有 JAX 转换一样,jax.jit()jax.vmap()都设计为可组合的,这意味着您可以用jit包装一个 vmapped 函数,或用vmap包装一个 jitted 函数,一切都会正常工作:

jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws) 
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32) 

自动微分

原文:jax.readthedocs.io/en/latest/automatic-differentiation.html

在本节中,您将学习 JAX 中自动微分(autodiff)的基本应用。JAX 具有一个非常通用的自动微分系统。计算梯度是现代机器学习方法的关键部分,本教程将引导您了解一些自动微分的入门主题,例如:

  • 1. 使用 jax.grad 计算梯度

  • 2. 在线性逻辑回归中计算梯度

  • 3. 对嵌套列表、元组和字典进行微分

  • 4. 使用 jax.value_and_grad 评估函数及其梯度

  • 5. 检查数值差异

还要确保查看高级自动微分教程,了解更多高级主题。

虽然理解自动微分的“内部工作原理”对于在大多数情况下使用 JAX 并不关键,但建议您观看这个非常易懂的视频,以深入了解发生的事情。

1. 使用jax.grad()计算梯度

在 JAX 中,您可以使用jax.grad()变换微分一个标量值函数:

import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0)) 
0.070650816 

jax.grad()接受一个函数并返回一个函数。如果你有一个 Python 函数f,它计算数学函数( f ),那么jax.grad(f)是一个 Python 函数,它计算数学函数( \nabla f )。这意味着grad(f)(x)表示值( \nabla f(x) )。

由于jax.grad()操作函数,您可以将其应用于其自身的输出,以任意次数进行微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0)) 
-0.13621868
0.25265405 

JAX 的自动微分使得计算高阶导数变得容易,因为计算导数的函数本身是可微的。因此,高阶导数就像堆叠转换一样容易。这可以在单变量情况下说明:

函数( f(x) = x³ + 2x² - 3x + 1 )的导数可以计算如下:

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f) 

函数( f )的高阶导数为:

[\begin{split} \begin{array}{l} f'(x) = 3x² + 4x -3\ f''(x) = 6x + 4\ f'''(x) = 6\ f^{iv}(x) = 0 \end{array} \end{split}]

在 JAX 中计算任何这些都像链接jax.grad()函数一样简单:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx) 

在( x=1 )处评估上述内容将给出:

[\begin{split} \begin{array}{l} f'(1) = 4\ f''(1) = 10\ f'''(1) = 6\ f^{iv}(1) = 0 \end{array} \end{split}]

使用 JAX:

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.)) 
4.0
10.0
6.0
0.0 
```  ## 2\. 在线性逻辑回归中计算梯度

下一个示例展示了如何在线性逻辑回归模型中使用`jax.grad()`计算梯度。首先,设置:

```py
key = jax.random.key(0)

def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ()) 

使用jax.grad()函数及其argnums参数对位置参数进行函数微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}') 
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32) 

jax.grad() API 直接对应于斯皮瓦克经典著作《流形上的微积分》(1965 年)中的优秀符号表示法,也用于苏斯曼和威斯登的《古典力学的结构与解释》(2015 年)及其《函数微分几何》(2013 年)。这两本书都是开放获取的。特别是,《函数微分几何》的“前言”部分为此符号的使用进行了辩护。

实际上,当使用argnums参数时,如果f是用于评估数学函数(f)的 Python 函数,则 Python 表达式jax.grad(f, i)评估为一个用于评估(\partial_i f)的 Python 函数。 ## 3. 对嵌套列表、元组和字典进行微分

由于 JAX 的 PyTree 抽象(详见处理 pytrees),关于标准 Python 容器的微分工作都能正常进行,因此你可以随意使用元组、列表和字典(及任意嵌套结构)。

继续前面的示例:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b})) 
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)} 

你可以创建自定义的 pytree 节点,以便与不仅仅是jax.grad(),还有其他 JAX 转换(jax.jit()jax.vmap()等)一起使用。 ## 4. 使用jax.value_and_grad评估函数及其梯度

另一个方便的函数是jax.value_and_grad(),可以在一次计算中高效地同时计算函数值和其梯度值。

继续前面的示例:

loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b)) 
loss value 3.0519385
loss value 3.0519385 
```  ## 5\. 对数值差异进行检查

关于导数的一大好处是,它们对有限差异的检查非常直观。

继续前面的示例:

```py
# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec)) 
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117 

JAX 提供了一个简单的便利函数,实际上做了相同的事情,但可以检查任意阶数的微分:

from jax.test_util import check_grads

check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives 

下一步

高级自动微分教程提供了关于如何在 JAX 后端实现本文档涵盖的思想的更高级和详细的解释。某些功能,如用于 JAX 可转换 Python 函数的自定义导数规则,依赖于对高级自动微分的理解,因此如果您感兴趣,请查看高级自动微分教程中的相关部分。

调试介绍

原文:jax.readthedocs.io/en/latest/debugging.html

本节介绍了一组内置的 JAX 调试方法 — jax.debug.print()jax.debug.breakpoint()jax.debug.callback() — 您可以将其与各种 JAX 转换一起使用。

让我们从 jax.debug.print() 开始。

JAX 的 debug.print 用于高级别

TL;DR 这是一个经验法则:

  • 对于使用 jax.jit()jax.vmap() 和其他动态数组值的跟踪,使用 jax.debug.print()

  • 对于静态值(例如 dtypes 和数组形状),使用 Python print()

回顾即时编译时,使用 jax.jit() 转换函数时,Python 代码在数组的抽象跟踪器的位置执行。因此,Python print() 函数只会打印此跟踪器值:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.) 
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

Python 的 print 在跟踪时间执行,即在运行时值存在之前。如果要打印实际的运行时值,可以使用 jax.debug.print()

@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.) 
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314 

类似地,在 jax.vmap() 内部,使用 Python 的 print 只会打印跟踪器;要打印正在映射的值,请使用 jax.debug.print()

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs) 
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314 

这里是使用 jax.lax.map() 的结果,它是一个顺序映射而不是向量化:

result = jax.lax.map(f, xs) 
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314 

注意顺序不同,如 jax.vmap()jax.lax.map() 以不同方式计算相同结果。在调试时,评估顺序的细节正是您可能需要检查的。

下面是一个关于 jax.grad() 的示例,其中 jax.debug.print() 仅打印前向传递。在这种情况下,行为类似于 Python 的 print(),但如果在调用期间应用 jax.jit(),它是一致的。

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.) 
jax.debug.print(x) -> 1.0 

有时,当参数彼此不依赖时,调用 jax.debug.print() 可能会以不同的顺序打印它们,当使用 JAX 转换进行分阶段时。如果需要原始顺序,例如首先是 x: ... 然后是 y: ...,请添加 ordered=True 参数。

例如:

@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2) 
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2 
Array(3, dtype=int32, weak_type=True) 

要了解更多关于 jax.debug.print() 及其详细信息,请参阅高级调试。

JAX 的 debug.breakpoint 用于类似 pdb 的调试

TL;DR 使用 jax.debug.breakpoint() 暂停您的 JAX 程序执行以检查值。

要在调试期间暂停编译的 JAX 程序的某些点,您可以使用 jax.debug.breakpoint()。提示类似于 Python 的 pdb,允许您检查调用堆栈中的值。实际上,jax.debug.breakpoint()jax.debug.callback() 的应用,用于捕获有关调用堆栈的信息。

要在 breakpoint 调试会话期间打印所有可用命令,请使用 help 命令。(完整的调试器命令、其强大之处及限制在高级调试中有详细介绍。)

这是调试器会话可能看起来的示例:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution 

JAX 调试器

对于依赖值的断点,您可以使用像jax.lax.cond()这样的运行时条件:

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint 
Array(2., dtype=float32, weak_type=True) 
f(2., 0.) # ==> Pauses during execution 

JAX 调试回调以增强调试期间的控制

jax.debug.print()jax.debug.breakpoint()都使用更灵活的jax.debug.callback()实现,它通过 Python 回调执行主机端逻辑,提供更大的控制。它与jax.jit()jax.vmap()jax.grad()和其他转换兼容(有关更多信息,请参阅外部回调的回调类型表)。

例如:

import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0); 
WARNING:root:Logged value: 1.0 

此回调与其他转换兼容,包括jax.vmap()jax.grad()

x = jnp.arange(5.0)
jax.vmap(f)(x); 
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0 
jax.grad(f)(1.0); 
WARNING:root:Logged value: 1.0 

这使得jax.debug.callback()在通用调试中非常有用。

您可以在外部回调中了解更多关于jax.debug.callback()和其他类型 JAX 回调的信息。

下一步

查看高级调试以了解更多关于在 JAX 中调试的信息。

伪随机数

原文:jax.readthedocs.io/en/latest/random-numbers.html

本节将重点讨论 jax.random 和伪随机数生成(PRNG);即,通过算法生成数列,其特性近似于从适当分布中抽样的随机数列的过程。

PRNG 生成的序列并非真正随机,因为它们实际上由其初始值决定,通常称为 seed,并且每一步的随机抽样都是由从一个样本到下一个样本传递的 state 的确定性函数决定。

伪随机数生成是任何机器学习或科学计算框架的重要组成部分。一般而言,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。

为了更好地理解 JAX 和 NumPy 在随机数生成方法上的差异,我们将在本节中讨论两种方法。

NumPy 中的随机数

NumPy 中的伪随机数生成由 numpy.random 模块本地支持。在 NumPy 中,伪随机数生成基于全局 state,可以使用 numpy.random.seed() 将其设置为确定性初始条件。

import numpy as np
np.random.seed(0) 

您可以使用以下命令检查状态的内容。

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state() 
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ... 

每次对随机函数调用都会更新 state

np.random.seed(0)
print_truncated_random_state() 
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ... 
_ = np.random.uniform()
print_truncated_random_state() 
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ... 

NumPy 允许您在单个函数调用中同时抽取单个数字或整个向量的数字。例如,您可以通过以下方式从均匀分布中抽取一个包含 3 个标量的向量:

np.random.seed(0)
print(np.random.uniform(size=3)) 
[0.5488135  0.71518937 0.60276338] 

NumPy 提供了顺序等效保证,这意味着连续抽取 N 个数字或一次抽样 N 个数字的向量将得到相同的伪随机序列:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3)) 
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338] 

JAX 中的随机数

JAX 的随机数生成与 NumPy 的方式有重要的区别,因为 NumPy 的 PRNG 设计使得同时保证多种理想特性变得困难。具体而言,在 JAX 中,我们希望 PRNG 生成是:

  1. 可复现的,

  2. 可并行化,

  3. 可向量化。

我们将在接下来讨论原因。首先,我们将集中讨论基于全局状态的伪随机数生成设计的影响。考虑以下代码:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo()) 
1.9791922366721637 

函数 foo 对从均匀分布中抽样的两个标量求和。

如果我们假设 bar()baz() 的执行顺序是可预测的,那么此代码的输出只能满足要求 #1。在 NumPy 中,这不是问题,因为它总是按照 Python 解释器定义的顺序执行代码。然而,在 JAX 中,情况就比较复杂了:为了执行效率,我们希望 JIT 编译器可以自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程需要同步全局状态,这会影响执行效率。

明确的随机状态

为了避免这个问题,JAX 避免使用隐式的全局随机状态,而是通过随机 key 显式地跟踪状态:

from jax import random

key = random.key(42)
print(key) 
Array((), dtype=key<fry>) overlaying:
[ 0 42] 

注意

本节使用由 jax.random.key() 生成的新型类型化 PRNG key,而不是由 jax.random.PRNGKey() 生成的旧型原始 PRNG key。有关详情,请参阅 JEP 9263:类型化 key 和可插拔 RNG。

一个 key 是一个具有特定 PRNG 实现对应的特殊数据类型的数组;在默认实现中,每个 key 由一对 uint32 值支持。

key 实际上是 NumPy 隐藏状态对象的替代品,但我们显式地将其传递给 jax.random() 函数。重要的是,随机函数消耗 key,但不修改它:将相同的 key 对象传递给随机函数将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key)) 
-0.18471177
-0.18471177 

即使使用不同的 random API,重复使用相同的 key 也可能导致相关的输出,这通常是不可取的。

经验法则是:永远不要重复使用 key(除非你希望得到相同的输出)。

为了生成不同且独立的样本,你必须在将 key 传递给随机函数之前显式地调用 split()

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration. 
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592 

(这里调用 del 并非必须,但我们这样做是为了强调一旦使用过的 key 不应再次使用。)

jax.random.split() 是一个确定性函数,它将一个 key 转换为若干独立(在伪随机性意义上)的新 key。我们保留其中一个作为 new_key,可以安全地将额外生成的唯一 subkey 作为随机函数的输入,然后永久丢弃它。如果你需要从正态分布中获取另一个样本,你需要再次执行 split(key),以此类推:关键的一点是,你永远不要重复使用同一个 key

调用 split(key) 的输出的哪一部分被称为 key,哪一部分被称为 subkey 并不重要。它们都是具有相同状态的独立 keykey/subkey 命名约定是一种典型的使用模式,有助于跟踪 key 如何被消耗:subkey 被用于随机函数的直接消耗,而 key 则保留用于稍后生成更多的随机性。

通常,上述示例可以简洁地写成

key, subkey = random.split(key) 

这会自动丢弃旧 key。值得注意的是,split() 不仅可以创建两个 key,还可以创建多个:

key, *forty_two_subkeys = random.split(key, num=43) 

缺乏顺序等价性

NumPy 和 JAX 随机模块之间的另一个区别涉及到上述的顺序等价性保证。

与 NumPy 类似,JAX 的随机模块也允许对向量进行抽样。但是,JAX 不提供顺序等价性保证,因为这样做会干扰 SIMD 硬件上的向量化(上述要求 #3)。

在下面的示例中,使用三个子密钥分别从正态分布中抽取 3 个值,与使用单个密钥并指定shape=(3,)会得到不同的结果:

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,))) 
individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ] 

缺乏顺序等价性使我们能够更高效地编写代码;例如,不用通过顺序循环生成上述的sequence,而是可以使用jax.vmap()以向量化方式计算相同的结果:

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys)) 
vectorized: [-0.04838832  0.10796154 -1.2226542 ] 

下一步

欲了解更多关于 JAX 随机数的信息,请参阅jax.random模块的文档。如果您对 JAX 随机数生成器的设计细节感兴趣,请参阅 JAX PRNG 设计。

处理 pytrees

原文:jax.readthedocs.io/en/latest/working-with-pytrees.html

JAX 内置支持类似字典(dicts)的数组对象,或者列表的列表的字典,或其他嵌套结构 — 在 JAX 中称为 pytrees。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“坑”和模式。

什么是 pytree?

一个 pytree 是由类似容器的 Python 对象构建的容器结构 — “叶子” pytrees 和/或更多的 pytrees。一个 pytree 可以包括列表、元组和字典。一个叶子是任何不是 pytree 的东西,比如一个数组,但一个单独的叶子也是一个 pytree。

在机器学习(ML)的上下文中,一个 pytree 可能包含:

  • 模型参数

  • 数据集条目

  • 强化学习代理观察

当处理数据集时,你经常会遇到 pytrees(比如列表的列表的字典)。

下面是一个简单 pytree 的示例。在 JAX 中,你可以使用 jax.tree.leaves(),从树中提取扁平化的叶子,如此处所示:

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}") 
[1, 'a', <object object at 0x7f3d0048f950>]   has 3 leaves: [1, 'a', <object object at 0x7f3d0048f950>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)] 

在 JAX 中,任何由类似容器的 Python 对象构建的树状结构都可以被视为 pytree。如果它们在 pytree 注册表中,则类被视为容器类,默认情况下包括列表、元组和字典。任何类型不在 pytree 容器注册表中的对象都将被视为树中的叶子节点。

可以通过注册类并使用指定如何扁平化树的函数来扩展 pytree 注册表以包括用户定义的容器类;请参见下面的自定义 pytree 节点。

JAX 提供了许多实用程序来操作 pytrees。这些可以在 jax.tree_util 子包中找到;为了方便起见,其中许多在 jax.tree 模块中有别名。

常见功能:jax.tree.map

最常用的 pytree 函数是 jax.tree.map()。它的工作方式类似于 Python 的原生 map,但透明地操作整个 pytree。

这里有一个例子:

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists) 
[[2, 4, 6], [2, 4], [2, 4, 6, 8]] 

jax.tree.map() 也允许在多个参数上映射一个N-ary函数。例如:

another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists) 
[[2, 4, 6], [2, 4], [2, 4, 6, 8]] 

当使用多个参数与 jax.tree.map() 时,输入的结构必须完全匹配。也就是说,列表必须有相同数量的元素,字典必须有相同的键,等等。

jax.tree.map 示例解释 ML 模型参数

此示例演示了在训练简单多层感知器(MLP)时,pytree 操作如何有用。

从定义初始模型参数开始:

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1]) 

使用 jax.tree.map() 检查初始参数的形状:

jax.tree.map(lambda x: x.shape, params) 
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}] 

接下来,定义训练 MLP 模型的函数:

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  ) 
```  ## 自定义 pytree 节点

本节解释了在 JAX 中如何通过使用 `jax.tree_util.register_pytree_node()` 和 `jax.tree.map()` 扩展将被视为 pytree 内部节点(pytree 节点)的 Python 类型集合。

你为什么需要这个?在前面的示例中,pytrees 被展示为列表、元组和字典,其他所有内容都作为 pytree 叶子。这是因为如果你定义了自己的容器类,它会被视为 pytree 叶子,除非你*注册*它到 JAX。即使你的容器类内部包含树形结构,这个情况也是一样的。例如:

```py
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
]) 
[<__main__.Special at 0x7f3d005a23e0>, <__main__.Special at 0x7f3d005a1960>] 

因此,如果你尝试使用 jax.tree.map() 来期望容器内的元素作为叶子,你会得到一个错误:

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ]) 
TypeError: unsupported operand type(s) for +: 'Special' and 'int' 

作为解决方案,JAX 允许通过全局类型注册表扩展被视为内部 pytree 节点的类型集合。此外,已注册类型的值被递归地遍历。

首先,使用 jax.tree_util.register_pytree_node() 注册一个新类型:

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

 Params:
 v: The value of the registered type to flatten.
 Returns:
 A pair of an iterable with the children to be flattened recursively,
 and some opaque auxiliary data to pass back to the unflattening recipe.
 The auxiliary data is stored in the treedef for use during unflattening.
 The auxiliary data could be used, for example, for dictionary keys.
 """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

 Params:
 aux_data: The opaque data that was specified during flattening of the
 current tree definition.
 children: The unflattened children

 Returns:
 A reconstructed object of the registered type, using the specified
 children and auxiliary data.
 """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
) 

现在你可以遍历特殊容器结构:

jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ]) 
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)] 

现代 Python 配备了有助于更轻松定义容器的有用工具。一些工具可以直接与 JAX 兼容,但其他的则需要更多关注。

例如,Python 中的 NamedTuple 子类不需要注册即可被视为 pytree 节点类型:

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
]) 
['Alice', 1, 2, 3, 'Bob', 4, 5, 6] 

注意,现在 name 字段出现为一个叶子,因为所有元组元素都是子元素。这就是当你不必费力注册类时会发生的情况。 ## Pytrees 和 JAX 变换

许多 JAX 函数,比如 jax.lax.scan(),操作的是数组的 pytrees。此外,所有 JAX 函数变换都可以应用于接受作为输入和输出的数组 pytrees 的函数。

一些 JAX 函数变换接受可选参数,指定如何处理某些输入或输出值(例如 in_axesout_axes 参数给 jax.vmap())。这些参数也可以是 pytrees,它们的结构必须对应于相应参数的 pytree 结构。特别是为了能够将这些参数 pytree 中的叶子与参数 pytree 中的值匹配起来,“匹配”参数 pytrees 的叶子与参数 pytrees 的值,这些参数 pytrees 通常受到一定限制。

例如,如果你将以下输入传递给 jax.vmap()(请注意,函数的输入参数被视为一个元组):

vmap(f, in_axes=(a1, {"k1": a2, "k2": a3})) 

然后,你可以使用以下 in_axes pytree 来指定仅映射 k2 参数(axis=0),其余不进行映射(axis=None):

vmap(f, in_axes=(None, {"k1": None, "k2": 0})) 

可选参数 pytree 结构必须匹配主输入 pytree 的结构。但是,可选参数可以选择作为“前缀” pytree 指定,这意味着一个单独的叶值可以应用于整个子 pytree。

例如,如果你有与上述相同的 jax.vmap() 输入,但希望仅对字典参数进行映射,你可以使用:

vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0}) 

或者,如果希望每个参数都被映射,可以编写一个应用于整个参数元组 pytree 的单个叶值:

vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0}) 

这恰好是jax.vmap()的默认in_axes值。

对于转换函数的特定输入或输出值的其他可选参数,例如jax.vmap()中的out_axes,相同的逻辑也适用于其他可选参数。 ## 显式键路径

在 pytree 中,每个叶子都有一个键路径。叶的键路径是一个list,列表的长度等于叶在 pytree 中的深度。每个是一个hashable 对象,表示对应的 pytree 节点类型中的索引。键的类型取决于 pytree 节点类型;例如,对于dict,键的类型与tuple的键的类型不同。

对于任何 pytree 节点实例的内置 pytree 节点类型,其键集是唯一的。对于具有此属性的节点组成的 pytree,每个叶的键路径都是唯一的。

JAX 提供了以下用于处理键路径的jax.tree_util.*方法:

  • jax.tree_util.tree_flatten_with_path(): 类似于jax.tree.flatten(),但返回键路径。

  • jax.tree_util.tree_map_with_path(): 类似于jax.tree.map(),但函数还接受键路径作为参数。

  • jax.tree_util.keystr(): 给定一个通用键路径,返回一个友好的读取器字符串表达式。

例如,一个用例是打印与某个叶值相关的调试信息:

import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}') 
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo 

要表达键路径,JAX 提供了几种内置 pytree 节点类型的默认键类型,即:

  • SequenceKey(idx: int): 适用于列表和元组。

  • DictKey(key: Hashable): 用于字典。

  • GetAttrKey(name: str): 适用于namedtuple和最好是自定义的 pytree 节点(更多见下一节)

您可以自由地为自定义节点定义自己的键类型。只要它们的__str__()方法也被覆盖为友好的表达式,它们将与jax.tree_util.keystr()一起使用。

for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}') 
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name')) 
```  ## 常见的 pytree 陷阱

本节介绍了在使用 JAX pytrees 时遇到的一些常见问题(“陷阱”)。

### 将 pytree 节点误认为叶子

一个常见的需要注意的问题是意外引入*树节点*而不是*叶子节点*:

```py
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes) 
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))] 

这里发生的是数组的shape是一个元组,它是一个 pytree 节点,其元素是叶子节点。因此,在映射中,不是在例如(2, 3)上调用jnp.ones,而是在23上调用。

解决方案将取决于具体情况,但有两种广泛适用的选项:

  • 重写代码以避免中间jax.tree.map()

  • 将元组转换为 NumPy 数组(np.array)或 JAX NumPy 数组(jnp.array),这样整个序列就成为一个叶子。

jax.tree_utilNone的处理

jax.tree_util 函数将None视为不存在的 pytree 节点,而不是叶子:

jax.tree.leaves([None, None, None]) 
[] 

要将None视为叶子,可以使用is_leaf参数:

jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None) 
[None, None, None] 

自定义 pytrees 和使用意外值进行初始化

另一个与用户定义的 pytree 对象常见的陷阱是,JAX 变换偶尔会使用意外值来初始化它们,因此在初始化时执行的任何输入验证可能会失败。例如:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`. 
TypeError: Cannot interpret '<object object at 0x7f3cce5742a0>' as a data type

The above exception was the direct cause of the following exception:

TypeError: Cannot determine dtype of <object object at 0x7f3cce5742a0>

During handling of the above exception, another exception occurred:

TypeError: Value '<object object at 0x7f3cce5742a0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX. 
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`. 
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:3289: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
  return array(a, dtype=dtype, copy=bool(copy), order=order) 
TypeError: Cannot interpret '<object object at 0x7f3cce574780>' as a data type

The above exception was the direct cause of the following exception:

TypeError: Cannot determine dtype of <object object at 0x7f3cce574780>

During handling of the above exception, another exception occurred:

TypeError: Value '<object object at 0x7f3cce574780>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX. 
  • 在第一种情况下,使用 jax.vmap(...)(tree),JAX 的内部使用 object() 值的数组来推断树的结构。

  • 在第二种情况下,使用 jax.jacobian(...)(tree),将一个将树映射到树的函数的雅可比矩阵定义为树的树。

潜在解决方案 1:

  • 自定义 pytree 类的 __init____new__ 方法通常应避免执行任何数组转换或其他输入验证,或者预期并处理这些特殊情况。例如:
class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a 

潜在解决方案 2:

  • 结构化您的自定义 tree_unflatten 函数,以避免调用 __init__。如果选择这条路线,请确保您的 tree_unflatten 函数在代码更新时与 __init__ 保持同步。例如:
def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj 
```  ## 常见 pytree 模式

本节涵盖了 JAX pytrees 中一些最常见的模式。

### 使用 `jax.tree.map` 和 `jax.tree.transpose` 对 pytree 进行转置

要对 pytree 进行转置(将树的列表转换为列表的树),JAX 提供了两个函数:{func} `jax.tree.map`(更基础)和 `jax.tree.transpose()`(更灵活、复杂且冗长)。

**选项 1:** 使用 `jax.tree.map()`。这里是一个例子:

```py
def tree_transpose(list_of_trees):
  """
 Converts a list of trees of identical structure into a single tree of lists.
 """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps) 
{'obs': [3, 4], 't': [1, 2]} 

选项 2: 对于更复杂的转置,使用 jax.tree.transpose(),它更冗长,但允许您指定更灵活的内部和外部 pytree 结构。例如:

jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
) 
{'obs': [3, 4], 't': [1, 2]} 

分片计算介绍

原文:jax.readthedocs.io/en/latest/sharded-computation.html

本教程介绍了 JAX 中单程序多数据(SPMD)代码的设备并行性。SPMD 是一种并行技术,可以在不同设备上并行运行相同的计算,比如神经网络的前向传播,可以在不同的输入数据上(比如批量中的不同输入)并行运行在不同的设备上,比如几个 GPU 或 Google TPU 上。

本教程涵盖了三种并行计算模式:

  • 通过jax.jit()自动并行化:编译器选择最佳的计算策略(也被称为“编译器接管”)。

  • 使用jax.jit()jax.lax.with_sharding_constraint()半自动并行化

  • 使用jax.experimental.shard_map.shard_map()进行全手动并行化:shard_map可以实现每个设备的代码和显式的通信集合

使用这些 SPMD 的思路,您可以将为一个设备编写的函数转换为可以在多个设备上并行运行的函数。

如果您在 Google Colab 笔记本中运行这些示例,请确保您的硬件加速器是最新的 Google TPU,方法是检查笔记本设置:Runtime > Change runtime type > Hardware accelerator > TPU v2(提供八个可用设备)。

import jax
jax.devices() 
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)] 

关键概念:数据分片

下面列出的所有分布式计算方法的关键是数据分片的概念,描述了如何在可用设备上布置数据。

JAX 如何理解数据在各个设备上的布局?JAX 的数据类型,jax.Array不可变数组数据结构,代表了在一个或多个设备上具有物理存储的数组,并且有助于使并行化成为 JAX 的核心特性。jax.Array对象是专为分布式数据和计算而设计的。每个jax.Array都有一个关联的jax.sharding.Sharding对象,描述了每个全局设备所需的全局数据的分片情况。当您从头开始创建jax.Array时,您还需要创建它的Sharding

在简单的情况下,数组被分片在单个设备上,如下所示:

import jax.numpy as jnp
arr = jnp.arange(32.0).reshape(4, 8)
arr.devices() 
{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)} 
arr.sharding 
SingleDeviceSharding(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)) 

若要更直观地表示存储布局,jax.debug模块提供了一些辅助工具来可视化数组的分片。例如,jax.debug.visualize_array_sharding()显示了数组如何存储在单个设备的内存中:

jax.debug.visualize_array_sharding(arr) 

 TPU 0 

要创建具有非平凡分片的数组,可以为数组定义一个jax.sharding规范,并将其传递给jax.device_put()

在这里,定义一个NamedSharding,它指定了一个带有命名轴的 N 维设备网格,其中jax.sharding.Mesh允许精确的设备放置:

# Pardon the boilerplate; constructing a sharding will become easier in future!
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((2, 4))
mesh = jax.sharding.Mesh(devices, ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding) 
NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y')) 

将该Sharding对象传递给jax.device_put(),就可以获得一个分片数组:

arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded) 
[[ 0\.  1\.  2\.  3\.  4\.  5\.  6\.  7.]
 [ 8\.  9\. 10\. 11\. 12\. 13\. 14\. 15.]
 [16\. 17\. 18\. 19\. 20\. 21\. 22\. 23.]
 [24\. 25\. 26\. 27\. 28\. 29\. 30\. 31.]] 

 TPU 0   TPU 1       TPU 2    TPU 3 

 TPU 6   TPU 7       TPU 4    TPU 5 

这里的设备编号并不按数字顺序排列,因为网格反映了设备的环形拓扑结构。

1. 通过jit实现自动并行化

一旦您有了分片数据,最简单的并行计算方法就是将数据简单地传递给jax.jit()编译的函数!在 JAX 中,您只需指定希望代码的输入和输出如何分区,编译器将会自动处理:1)内部所有内容的分区;2)跨设备的通信的编译。

jit背后的 XLA 编译器包含了优化跨多个设备的计算的启发式方法。在最简单的情况下,这些启发式方法可以归结为计算跟随数据

为了演示 JAX 中自动并行化的工作原理,下面是一个使用jax.jit()装饰的延迟执行函数的示例:这是一个简单的逐元素函数,其中每个分片的计算将在与该分片关联的设备上执行,并且输出也以相同的方式进行分片:

@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)

print("shardings match:", result.sharding == arr_sharded.sharding) 
shardings match: True 

随着计算变得更加复杂,编译器会决定如何最佳地传播数据的分片。

在这里,您沿着x的主轴求和,并可视化结果值如何存储在多个设备上(使用jax.debug.visualize_array_sharding()):

@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result) 
 TPU 0,6 TPU 1,7  TPU 2,4 TPU 3,5 

[48\. 52\. 56\. 60\. 64\. 68\. 72\. 76.] 

结果部分复制:即数组的前两个元素复制到设备06,第二个到17,依此类推。

2. 使用约束进行半自动分片

如果您希望在特定计算中对使用的分片进行一些控制,JAX 提供了with_sharding_constraint()函数。您可以使用jax.lax.with_sharding_constraint()(而不是jax.device_put())与jax.jit()一起更精确地控制编译器如何约束中间值和输出的分布。

例如,假设在上面的f_contract中,您希望输出不是部分复制,而是完全在八个设备上进行分片:

@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  # mesh = jax.create_mesh((8,), 'x')
  devices = mesh_utils.create_device_mesh(8)
  mesh = jax.sharding.Mesh(devices, 'x')
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result) 
 TPU 0  TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4  TPU 5 

[48\. 52\. 56\. 60\. 64\. 68\. 72\. 76.] 

这将为您提供具有所需输出分片的函数。

3. 使用shard_map进行手动并行处理

在上述自动并行化方法中,您可以编写一个函数,就像在操作完整数据集一样,jit将会将该计算分配到多个设备上执行。相比之下,使用jax.experimental.shard_map.shard_map(),您需要编写处理单个数据片段的函数,而shard_map将构建完整的函数。

shard_map的工作方式是在设备mesh上映射函数(shard_map在 shards 上进行映射)。在下面的示例中:

  • 与以往一样,jax.sharding.Mesh允许精确的设备放置,使用轴名称参数来表示逻辑和物理轴名称。

  • in_specs参数确定了分片大小。out_specs参数标识了如何将块重新组装在一起。

注意: 如果需要,jax.experimental.shard_map.shard_map()代码可以在jax.jit()内部工作。

from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')

f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr) 
Array([ 1\.        ,  2.682942  ,  2.818595  ,  1.28224   , -0.513605  ,
       -0.9178486 ,  0.44116896,  2.3139732 ,  2.9787164 ,  1.824237  ,
       -0.08804226, -0.99998045, -0.07314599,  1.8403342 ,  2.9812148 ,
        2.3005757 ,  0.42419332, -0.92279506, -0.50197446,  1.2997544 ,
        2.8258905 ,  2.6733112 ,  0.98229736, -0.69244075, -0.81115675,
        0.7352965 ,  2.525117  ,  2.912752  ,  1.5418116 , -0.32726777,
       -0.97606325,  0.19192469], dtype=float32) 

您编写的函数只“看到”数据的单个批次,可以通过打印设备本地形状来检查:

x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) 
global shape: x.shape=(32,)
device local shape: x.shape=(4,) 

因为每个函数只“看到”数据的设备本地部分,这意味着像聚合的函数需要额外的思考。

例如,这是jax.numpy.sum()shard_map的示例:

def f(x):
  return jnp.sum(x, keepdims=True)

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) 
Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32) 

您的函数f分别在每个分片上运行,并且结果的总和反映了这一点。

如果要跨分片进行求和,您需要显式请求,使用像jax.lax.psum()这样的集合操作:

def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) 
Array(496, dtype=int32) 

因为输出不再具有分片维度,所以设置out_specs=P()(请记住,out_specs参数标识如何在shard_map中将块重新组装在一起)。

比较这三种方法

在我们记忆中掌握这些概念后,让我们比较简单神经网络层的三种方法。

首先像这样定义您的规范函数:

@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias) 
import numpy as np
rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias) 
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32) 

您可以使用jax.jit()自动以分布式方式运行此操作,并传递适当分片的数据。

如果您以相同的方式分片xweights的主轴,则矩阵乘法将自动并行发生:

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), 'x')
sharding = jax.sharding.NamedSharding(mesh, P('x'))

x_sharded = jax.device_put(x, sharding)
weights_sharded = jax.device_put(weights, sharding)

layer(x_sharded, weights_sharded, bias) 
Array([0.02138912, 0.893112  , 0.59892005, 0.97742504], dtype=float32) 

或者,您可以在函数中使用jax.lax.with_sharding_constraint()自动分发未分片的输入:

@jax.jit
def layer_auto(x, weights, bias):
  x = jax.lax.with_sharding_constraint(x, sharding)
  weights = jax.lax.with_sharding_constraint(weights, sharding)
  return layer(x, weights, bias)

layer_auto(x, weights, bias)  # pass in unsharded inputs 
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32) 

最后,您可以使用shard_map以相同的方式执行此操作,使用jax.lax.psum()指示矩阵乘积所需的跨分片集合:

from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias) 
Array([0.02138914, 0.89311206, 0.5989201 , 0.97742516], dtype=float32) 

下一步

本教程简要介绍了在 JAX 中分片和并行计算的概念。

要深入了解每种 SPMD 方法,请查看以下文档:

  • 分布式数组和自动并行化

  • 使用shard_map进行 SPMD 多设备并行性

posted @ 2024-06-21 14:07  绝不原创的飞龙  阅读(35)  评论(0编辑  收藏  举报