JAX-中文文档-六-

JAX 中文文档(六)

原文:jax.readthedocs.io/en/latest/

高级教程

原文:jax.readthedocs.io/en/latest/advanced_guide.html

本节包含更高级主题的示例和教程,如多核计算、自定义操作及更深入的应用

示例

  • 使用 tensorflow/datasets 进行简单神经网络训练

  • 使用 PyTorch 数据加载进行简单神经网络训练

  • 贝叶斯推断的自动批处理

并行计算

  • 在多主机和多进程环境中使用 JAX

  • 分布式数组和自动并行化

  • 带有 shard_map 的 SPMD 多设备并行性

  • API 规范

  • 集合教程

  • 玩具示例

  • 多主机/多进程环境中的分布式数据加载

  • 带有 xmap 的命名轴和易于修改的并行性

自动微分

  • 自动微分食谱

  • 为可转换为 JAX 的 Python 函数编写自定义导数规则

  • 使用 jax.checkpoint(又名 jax.remat)控制自动微分的保存值

JAX 内部机制

  • JAX 原语的工作原理

  • 在 JAX 中编写自定义 Jaxpr 解释器

  • 使用 C++ 和 CUDA 为 GPU 的自定义操作

  • 检查正确性

深入探讨

  • JAX 中的广义卷积

训练一个简单的神经网络,使用 tensorflow/datasets 进行数据加载

原文:jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

在 Colab 中打开 在 Kaggle 中打开

neural_network_and_data_loading.ipynb 衍生

JAX

让我们结合我们在快速入门中展示的所有内容来训练一个简单的神经网络。我们将首先使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算。我们将使用 tensorflow/datasets 数据加载 API 来加载图像和标签(因为它非常出色,世界上不需要再另外一种数据加载库 😛)。

当然,您可以使用 JAX 与任何与 NumPy 兼容的 API,使模型的指定更加即插即用。这里,仅供解释用途,我们不会使用任何神经网络库或特殊的 API 来构建我们的模型。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random 

超参数

让我们先处理一些簿记事项。

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0)) 

自动批量预测

让我们首先定义我们的预测函数。请注意,我们为单个图像示例定义了这个函数。我们将使用 JAX 的 vmap 函数自动处理小批量数据,而不会影响性能。

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits) 

让我们检查我们的预测函数只适用于单个图像。

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape) 
(10,) 
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!') 
Invalid shapes! 
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape) 
(10, 10) 

到目前为止,我们已经具备了定义和训练神经网络所需的所有要素。我们已经构建了一个自动批处理版本的 predict 函数,应该可以在损失函数中使用。我们应该能够使用 grad 对神经网络参数的损失函数进行求导。最后,我们应该能够使用 jit 加速整个过程。

实用函数和损失函数

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)] 

使用 tensorflow/datasets 进行数据加载

JAX 主要专注于程序转换和支持加速的 NumPy,因此我们不包括数据加载或整理在 JAX 库中。已经有很多出色的数据加载器,所以我们只需使用它们,而不是重新发明轮子。我们将使用 tensorflow/datasets 数据加载器。

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels) 
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape) 
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10) 

训练循环

import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc)) 
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341 

我们现在已经使用了大部分 JAX API:grad 用于求导,jit 用于加速和 vmap 用于自动向量化。我们使用 NumPy 来指定所有的计算,并从 tensorflow/datasets 借用了优秀的数据加载器,并在 GPU 上运行了整个过程。

训练一个简单的神经网络,使用 PyTorch 进行数据加载

原文:jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

在 Colab 中打开 在 Kaggle 中打开

版权所有 2018 年 JAX 作者。

根据 Apache 许可证第 2.0 版许可使用本文件;除非符合许可证,否则不得使用本文件。您可以在以下链接获取许可证的副本

https://www.apache.org/licenses/LICENSE-2.0

除非适用法律要求或书面同意,否则在许可证下发布的软件是按“原样”分发的,不提供任何明示或暗示的担保或条件。有关特定语言下的权限和限制,请参阅许可证。

JAX

让我们结合我们在快速入门中展示的一切,来训练一个简单的神经网络。我们将首先使用 JAX 进行计算,指定并训练一个简单的 MLP 来处理 MNIST 数据集。我们将使用 PyTorch 的数据加载 API 加载图像和标签(因为它非常棒,世界上不需要另一个数据加载库)。

当然,您可以使用 JAX 与任何与 NumPy 兼容的 API,以使模型的指定更加即插即用。在这里,仅用于解释目的,我们不会使用任何神经网络库或特殊的 API 来构建我们的模型。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random 

超参数

让我们先处理一些记录事项。

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0)) 

自动批处理预测

让我们首先定义我们的预测函数。请注意,我们正在为单个图像示例定义这个函数。我们将使用 JAX 的vmap函数自动处理小批量,而无需性能损失。

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits) 

让我们检查我们的预测函数是否只适用于单个图像。

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape) 
(10,) 
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!') 
Invalid shapes! 
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape) 
(10, 10) 

到这一步,我们已经具备了定义和训练神经网络所需的所有要素。我们已经构建了predict的自动批处理版本,我们应该能够在损失函数中使用它。我们应该能够使用grad来计算损失相对于神经网络参数的导数。最后,我们应该能够使用jit来加速整个过程。

实用工具和损失函数

def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)] 

使用 PyTorch 进行数据加载

JAX 专注于程序转换和支持加速器的 NumPy,因此我们不在 JAX 库中包括数据加载或数据处理。已经有很多出色的数据加载器,所以我们只需使用它们,而不是重新发明轮子。我们将获取 PyTorch 的数据加载器,并制作一个小的 shim 以使其与 NumPy 数组兼容。

!pip  install  torch  torchvision 
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0) 
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32)) 
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) 
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done! 
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets) 
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
  warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
  warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
  warnings.warn("test_labels has been renamed targets") 

训练循环

import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc)) 
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607 

我们现在已经完全使用了 JAX API:grad 用于求导,jit 用于加速,vmap 用于自动向量化。我们使用 NumPy 来指定所有的计算,借用了 PyTorch 中优秀的数据加载器,并且在 GPU 上运行整个过程。

贝叶斯推断的自动批处理

原文:jax.readthedocs.io/en/latest/notebooks/vmapped_log_probs.html

在 Colab 中打开 在 Kaggle 中打开

本笔记演示了一个简单的贝叶斯推断示例,其中自动批处理使用户代码更易于编写、更易于阅读,减少了错误的可能性。

灵感来自@davmre 的一个笔记本。

import functools
import itertools
import re
import sys
import time

from matplotlib.pyplot import *

import jax

from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random

import numpy as np
import scipy as sp 

生成一个虚拟的二分类数据集

np.random.seed(10009)

num_features = 10
num_points = 100

true_beta = np.random.randn(num_features).astype(jnp.float32)
all_x = np.random.randn(num_points, num_features).astype(jnp.float32)
y = (np.random.rand(num_points) < sp.special.expit(all_x.dot(true_beta))).astype(jnp.int32) 
y 
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32) 

编写模型的对数联合函数

我们将编写一个非批处理版本、一个手动批处理版本和一个自动批处理版本。

非批量化

def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result 
log_joint(np.random.randn(num_features)) 
Array(-213.2356, dtype=float32) 
# This doesn't work, because we didn't write `log_prob()` to handle batching.
try:
  batch_size = 10
  batched_test_beta = np.random.randn(batch_size, num_features)

  log_joint(np.random.randn(batch_size, num_features))
except ValueError as e:
  print("Caught expected exception " + str(e)) 
Caught expected exception Incompatible shapes for broadcasting: shapes=[(100,), (100, 10)] 

手动批处理

def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=1.),
                           axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta.T).T)),
                           axis=-1)
    return result 
batch_size = 10
batched_test_beta = np.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta) 
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32) 

使用 vmap 进行自动批处理

它只是有效地工作。

vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta) 
Array([-147.84033 , -207.02205 , -109.26075 , -243.80833 , -163.0291  ,
       -143.84848 , -160.28773 , -113.771706, -126.60544 , -190.81992 ],      dtype=float32) 

自包含的变分推断示例

从上面复制了一小段代码。

设置(批量化的)对数联合函数

@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `jnp.sum`.
    result = result + jnp.sum(jsp.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + jnp.sum(-jnp.log(1 + jnp.exp(-(2*y-1) * jnp.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint)) 

定义 ELBO 及其梯度

def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
    return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1))) 

使用 SGD 优化 ELBO

def normal_sample(key, shape):
  """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.key(10003)

beta_loc = jnp.zeros(num_features, jnp.float32)
beta_log_scale = jnp.zeros(num_features, jnp.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val)) 
0	-180.8538818359375
10	-113.06045532226562
20	-102.73727416992188
30	-99.787353515625
40	-98.90898132324219
50	-98.29745483398438
60	-98.18632507324219
70	-97.57972717285156
80	-97.28599548339844
90	-97.46996307373047
100	-97.4771728515625
110	-97.5806655883789
120	-97.4943618774414
130	-97.50271606445312
140	-96.86396026611328
150	-97.44197845458984
160	-97.06941223144531
170	-96.84028625488281
180	-97.21336364746094
190	-97.56503295898438
200	-97.26397705078125
210	-97.11979675292969
220	-97.39595031738281
230	-97.16831970214844
240	-97.118408203125
250	-97.24345397949219
260	-97.29788970947266
270	-96.69286346435547
280	-96.96438598632812
290	-97.30055236816406
300	-96.63591766357422
310	-97.0351791381836
320	-97.52909088134766
330	-97.28811645507812
340	-97.07321166992188
350	-97.15619659423828
360	-97.25881958007812
370	-97.19515228271484
380	-97.13092041015625
390	-97.11726379394531
400	-96.938720703125
410	-97.26676940917969
420	-97.35322570800781
430	-97.21007537841797
440	-97.28434753417969
450	-97.1630859375
460	-97.2612533569336
470	-97.21343994140625
480	-97.23997497558594
490	-97.14913940429688
500	-97.23527526855469
510	-96.93419647216797
520	-97.21209716796875
530	-96.82575988769531
540	-97.01284790039062
550	-96.94175720214844
560	-97.16520690917969
570	-97.29165649414062
580	-97.42941284179688
590	-97.24370574951172
600	-97.15222930908203
610	-97.49844360351562
620	-96.9906997680664
630	-96.88956451416016
640	-96.89968872070312
650	-97.13793182373047
660	-97.43705749511719
670	-96.99235534667969
680	-97.15623474121094
690	-97.1869125366211
700	-97.11160278320312
710	-97.78105163574219
720	-97.23226165771484
730	-97.16206359863281
740	-96.99581909179688
750	-96.6672134399414
760	-97.16795349121094
770	-97.51435089111328
780	-97.28900146484375
790	-96.91226196289062
800	-97.17100524902344
810	-97.29047393798828
820	-97.16242980957031
830	-97.19107055664062
840	-97.56382751464844
850	-97.00194549560547
860	-96.86555480957031
870	-96.76338195800781
880	-96.83660888671875
890	-97.12178039550781
900	-97.09554290771484
910	-97.0682373046875
920	-97.11947631835938
930	-96.87930297851562
940	-97.45624542236328
950	-96.69279479980469
960	-97.29376220703125
970	-97.3353042602539
980	-97.34962463378906
990	-97.09675598144531 

显示结果

虽然覆盖率不及理想,但也不错,而且没有人说变分推断是精确的。

figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best') 
<matplotlib.legend.Legend at 0x7f6a2c3c86a0> 

../_images/f3f380106b7365b483cc90c02f9030fe13977e2a0e954dfada1276bb3d3e0444.png

在多主机和多进程环境中使用 JAX

原文:jax.readthedocs.io/en/latest/multi_process.html

介绍

本指南解释了如何在 GPU 集群和Cloud TPU pod 等环境中使用 JAX,在这些环境中,加速器分布在多个 CPU 主机或 JAX 进程上。我们将这些称为“多进程”环境。

本指南专门介绍了如何在多进程设置中使用集体通信操作(例如 jax.lax.psum() ),尽管根据您的用例,其他通信方法也可能有用(例如 RPC,mpi4jax)。如果您尚未熟悉 JAX 的集体操作,建议从分片计算部分开始。在 JAX 的多进程环境中,重要的要求是加速器之间的直接通信链路,例如 Cloud TPU 的高速互连或NCCL 用于 GPU。这些链路允许集体操作在多个进程的加速器上高性能运行。

多进程编程模型

关键概念:

  • 您必须在每个主机上至少运行一个 JAX 进程。

  • 您应该使用 jax.distributed.initialize() 初始化集群。

  • 每个进程都有一组独特的本地设备可以访问。全局设备是所有进程的所有设备集合。

  • 使用标准的 JAX 并行 API,如 jit()(参见分片计算入门教程)和 shard_map()。jax.jit 仅接受全局形状的数组。shard_map 允许您按设备形状进行降级。

  • 确保所有进程按照相同顺序运行相同的并行计算。

  • 确保所有进程具有相同数量的本地设备。

  • 确保所有设备相同(例如,全部为 V100 或全部为 H100)。

启动 JAX 进程

与其他分布式系统不同,其中单个控制节点管理多个工作节点,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为单程序多数据(SPMD)模型。通常,在每个进程中运行相同的 JAX Python 程序,每个进程的执行之间只有轻微差异(例如,不同的进程将加载不同的输入数据)。此外,您必须手动在每个主机上运行您的 JAX 程序! JAX 不会从单个程序调用自动启动多个进程。

(对于多个进程的要求,这就是为什么本指南不作为笔记本提供的原因——我们目前没有好的方法来从单个笔记本管理多个 Python 进程。)

初始化集群

要初始化集群,您应该在每个进程的开始调用 jax.distributed.initialize()jax.distributed.initialize() 必须在程序中的任何 JAX 计算执行之前早些时候调用。

API jax.distributed.initialize() 接受几个参数,即:

  • coordinator_address:集群中进程 0 的 IP 地址,以及该进程上可用的一个端口。进程 0 将启动一个通过该 IP 地址和端口暴露的 JAX 服务,集群中的其他进程将连接到该服务。

  • coordinator_bind_address:集群中进程 0 上的 JAX 服务将绑定到的 IP 地址和端口。默认情况下,它将使用与 coordinator_address 相同端口的所有可用接口进行绑定。

  • num_processes:集群中的进程数

  • process_id:本进程的 ID 号码,范围为[0 .. num_processes)

  • local_device_ids:将当前进程的可见设备限制为 local_device_ids

例如,在 GPU 上,典型用法如下:

import jax

jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
                           num_processes=2,
                           process_id=0) 

在 Cloud TPU、Slurm 和 Open MPI 环境中,你可以简单地调用 jax.distributed.initialize() 而无需参数。参数的默认值将自动选择。在使用 Slurm 和 Open MPI 运行 GPU 时,假定每个 GPU 启动一个进程,即每个进程只分配一个可见本地设备。否则假定每个主机启动一个进程,即每个进程将分配所有本地设备。只有当通过 mpirun/mpiexec 启动 JAX 进程时才会使用 Open MPI 自动初始化。

import jax

jax.distributed.initialize() 

在当前 TPU 上,调用 jax.distributed.initialize() 目前是可选的,但建议使用,因为它启用了额外的检查点和健康检查功能。

本地与全局设备

在开始从您的程序中运行多进程计算之前,了解本地全局设备之间的区别是很重要的。

进程的本地设备是它可以直接寻址和启动计算的设备。 例如,在 GPU 集群上,每个主机只能在直接连接的 GPU 上启动计算。在 Cloud TPU pod 上,每个主机只能在直接连接到该主机的 8 个 TPU 核心上启动计算(有关更多详情,请参阅Cloud TPU 系统架构文档)。你可以通过 jax.local_devices() 查看进程的本地设备。

全局设备是跨所有进程的设备。 一个计算可以跨进程的设备并通过设备之间的直接通信链路执行集体操作,只要每个进程在其本地设备上启动计算即可。你可以通过 jax.devices() 查看所有可用的全局设备。一个进程的本地设备总是全局设备的一个子集。

运行多进程计算

那么,你到底如何运行涉及跨进程通信的计算呢? 使用与单进程中相同的并行评估 API!

例如,shard_map() 可以用于在多个进程间并行计算。(如果您还不熟悉如何使用 shard_map 在单个进程内的多个设备上运行,请参阅分片计算介绍教程。)从概念上讲,这可以被视为在跨主机分片的单个数组上运行 pmap,其中每个主机只“看到”其本地分片的输入和输出。

下面是多进程 pmap 的实际示例:

# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
>>> jax.device_count()  # total number of accelerator devices in the cluster
32
>>> jax.local_device_count()  # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) 

非常重要的是,所有进程以相同的跨进程计算顺序运行。 在每个进程中运行相同的 JAX Python 程序通常就足够了。尽管运行相同程序,但仍需注意可能导致不同顺序计算的一些常见陷阱:

  • 将不同形状的输入传递给同一并行函数的进程可能导致挂起或不正确的返回值。只要它们在进程间产生相同形状的每设备数据分片,不同形状的输入是安全的;例如,传递不同的前导批次大小以在不同的本地设备数上运行是可以的,但是每个进程根据不同的最大示例长度填充其批次是不行的。

  • “最后一批”问题发生在并行函数在(训练)循环中调用时,其中一个或多个进程比其余进程更早退出循环。这将导致其余进程挂起,等待已经完成的进程开始计算。

  • 基于集合的非确定性顺序的条件可能导致代码进程挂起。例如,在当前 Python 版本上遍历 set 或者 Python 3.7 之前的 dict 可能会导致不同进程的顺序不同,即使插入顺序相同也是如此。

分布式数组和自动并行化

原文:jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

在 Colab 中打开 在 Kaggle 中打开

本教程讨论了通过 jax.Array 实现的并行计算,这是 JAX v0.4.1 及更高版本中可用的统一数组对象模型。

import os

import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp 

⚠️ 警告:此笔记本需要 8 个设备才能运行。

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run") 

简介和一个快速示例

通过阅读这本教程笔记本,您将了解 jax.Array,一种用于表示数组的统一数据类型,即使物理存储跨越多个设备。您还将学习如何使用 jax.Arrayjax.jit 结合,实现基于编译器的自动并行化。

在我们逐步思考之前,这里有一个快速示例。首先,我们将创建一个跨多个设备分片的 jax.Array

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding 
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

接下来,我们将对其应用计算,并可视化结果值如何存储在多个设备上:

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

jnp.sin 应用的评估已自动并行化,该应用跨存储输入值(和输出值)的设备:

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready() 
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached 
5 loops, best of 5: 9.69 ms per loop 
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready() 
5 loops, best of 5: 1.86 ms per loop 

现在让我们更详细地查看每个部分!

Sharding 描述了如何将数组值布局在跨设备的内存中。

Sharding 基础知识和 PositionalSharding 子类

要在多个设备上并行计算,我们首先必须在多个设备上布置输入数据。

在 JAX 中,Sharding 对象描述了分布式内存布局。它们可以与 jax.device_put 结合使用,生成具有分布式布局的值。

例如,这里是一个单设备 Sharding 的值:

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192)) 
jax.debug.visualize_array_sharding(x) 
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 

在这里,我们使用 jax.debug.visualize_array_sharding 函数来展示内存中存储值 x 的位置。整个 x 存储在单个设备上,所以可视化效果相当无聊!

但是我们可以通过使用 jax.device_putSharding 对象将 x 分布在多个设备上。首先,我们使用 mesh_utils.create_device_mesh 制作一个 Devicesnumpy.ndarray,该函数考虑了硬件拓扑以确定 Device 的顺序:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,)) 

然后,我们创建一个 PositionalSharding 并与 device_put 一起使用:

from jax.sharding import PositionalSharding

sharding = PositionalSharding(devices)

x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x) 
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘ 

这里的 sharding 是一个 PositionalSharding,它的作用类似于一个具有设备集合作为元素的数组:

sharding 
PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]) 

这里的设备编号不是按数字顺序排列的,因为网格反映了设备的基础环形拓扑结构。

通过编写 PositionalSharding(ndarray_of_devices),我们确定了设备顺序和初始形状。然后我们可以对其进行重新形状化:

sharding.reshape(8, 1) 
PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]]) 
sharding.reshape(4, 2) 
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]]) 

要使用device_put与数据数组x,我们可以将sharding重新形状为与x.shape同余的形状,这意味着具有与x.shape相同长度的形状,并且其中每个元素均匀地分割对应x.shape的元素:

def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
  return (len(x_shape) == len(sharding_shape) and
          all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape))) 

例如,我们可以将sharding重新形状为(4, 2),然后在device_put中使用它:

sharding = sharding.reshape(4, 2)
print(sharding) 
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]]) 
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里的y代表与x相同的,但其片段(即切片)存储在不同设备的内存中。

不同的PositionalSharding形状会导致结果的不同分布布局(即分片):

sharding = sharding.reshape(1, 8)
print(sharding) 
PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]]) 
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

在某些情况下,我们不只是想将x的每个切片存储在单个设备的内存中;我们可能希望在多个设备的内存中复制一些切片,即在多个设备的内存中存储切片的值。

使用PositionalSharding,我们可以通过调用 reducer 方法replicate来表达复制:

sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True)) 
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]]) 
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 

这里的可视化显示了x沿其第二维以两种方式分片(而不沿第一维分片),每个片段都复制了四种方式(即存储在四个设备内存中)。

replicate方法类似于熟悉的 NumPy 数组缩减方法,如.sum().prod()。它沿着一个轴执行集合并操作。因此,如果sharding的形状为(4, 2),那么sharding.replicate(0, keepdims=True)的形状为(1, 2)sharding.replicate(1, keepdims=True)的形状为(4, 1)。与 NumPy 方法不同,keepdims=True实际上是默认的,因此减少的轴不会被压缩:

print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape) 
(1, 2)
(4, 1) 
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y) 
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘ 

NamedSharding提供了一种使用名称表达分片的方式。

到目前为止,我们已经使用了PositionalSharding,但还有其他表达分片的替代方法。实际上,Sharding是一个接口,任何实现该接口的类都可以与device_put等函数一起使用。

另一种方便的表达分片的方法是使用NamedSharding

from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

我们可以定义一个辅助函数使事情更简单:

devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec) 
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

在这里,我们使用P('a', 'b')来表达x的第一和第二轴应该分片到设备网格轴'a''b'上。我们可以轻松切换到P('b', 'a')以在不同设备上分片x的轴:

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y) 
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘ 

这里,因为P('a', None)没有提及Mesh轴名'b',我们在轴'b'上得到了复制。这里的None只是一个占位符,用于与值x的第二轴对齐,而不表示在任何网格轴上进行分片。(简写方式是,尾部的None可以省略,因此P('a', None)的意思与P('a')相同。但是明确说明并不会有害!)

要仅在x的第二轴上进行分片,我们可以在PartitionSpec中使用None占位符。

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 

对于固定的网格,我们甚至可以将x的一个逻辑轴分割到多个设备网格轴上:

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y) 
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘ 

使用NamedSharding可以轻松定义一次设备网格并为其轴命名,然后只需在需要时在每个device_putPartitionSpec中引用这些名称。

计算遵循数据分片并自动并行化

使用分片输入数据,编译器可以给我们并行计算。特别是,用 jax.jit 装饰的函数可以在分片数组上操作,而无需将数据复制到单个设备上。相反,计算遵循分片:基于输入数据的分片,编译器决定中间结果和输出值的分片,并并行评估它们,必要时甚至插入通信操作。

例如,最简单的计算是逐元素的:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y) 
input sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里对于逐元素操作 jnp.sin,编译器选择了输出分片与输入相同。此外,编译器自动并行化计算,因此每个设备都可以并行计算其输出片段。

换句话说,即使我们将 jnp.sin 的计算写成单台机器执行,编译器也会为我们拆分计算并在多个设备上执行。

我们不仅可以对逐元素操作执行相同操作。考虑使用分片输入的矩阵乘法:

y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w) 
lhs sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
rhs sharding:
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
out sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里编译器选择了输出分片,以便最大化并行计算:无需通信,每个设备已经具有计算其输出分片所需的输入分片。

我们如何确保它实际上是并行运行的?我们可以进行简单的时间实验:

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single) 
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z)) 
True 
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready() 
5 loops, best of 5: 19.3 ms per loop 
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready() 
5 loops, best of 5: 3.25 ms per loop 

即使复制一个分片的 Array,也会产生具有输入分片的结果:

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

因此,当我们使用 jax.device_put 明确分片数据并对该数据应用函数时,编译器会尝试并行化计算并决定输出分片。这种对分片数据的策略是JAX 遵循显式设备放置策略的泛化

当明确分片不一致时,JAX 会报错

但是如果计算的两个参数在不同的设备组上明确放置,或者设备顺序不兼容,会发生错误:

import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red')
  print(textwrap.fill(f'{name}: {str(e)}')) 
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])

y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
try: y + z
except ValueError as e: print_exception(e) 
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3] on platform TPU and
another array's device ids [4, 5, 6, 7] on platform TPU 
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)

y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))
try: y + z
except ValueError as e: print_exception(e) 
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform
TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on
platform TPU 

我们说通过 jax.device_put 明确放置或分片的数组已经锁定在它们的设备上,因此不会自动移动。请查看 设备放置常见问题解答 获取更多信息。

当数组没有使用 jax.device_put 明确放置或分片时,它们会放置在默认设备上并未锁定。与已锁定数组不同,未锁定数组可以自动移动和重新分片:也就是说,未锁定数组可以作为计算的参数,即使其他参数明确放置在不同的设备上。

例如,jnp.zerosjnp.arangejnp.array 的输出都是未锁定的:

y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!') 
no error! 

限制在 jit 代码中的中间片段

虽然编译器将尝试决定函数的中间值和输出应如何分片,但我们还可以使用 jax.lax.with_sharding_constraint 来给它提供提示。使用 jax.lax.with_sharding_constraint 类似于 jax.device_put,不同之处在于我们在分阶段函数(即 jit 装饰的函数)内部使用它:

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2)) 
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
  return y 
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.replicate())
  return y 
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 

通过添加 with_sharding_constraint,我们限制了输出的分片。除了尊重特定中间变量的注释外,编译器还会使用注释来决定其他值的分片。

经常的好做法是注释计算的输出,例如根据值最终如何被使用来注释它们。

示例:神经网络

⚠️ 警告:以下内容旨在简单演示使用 jax.Array 进行自动分片传播,但可能不反映实际示例的最佳实践。 例如,实际示例可能需要更多使用 with_sharding_constraint

我们可以利用 jax.device_putjax.jit 的计算跟随分片特性来并行化神经网络中的计算。以下是基于这种基本神经网络的一些简单示例:

import jax
import jax.numpy as jnp 
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1)) 
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss)) 
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b

def init_model(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

    return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) 

8 路批数据并行

sharding = PositionalSharding(jax.devices()).reshape(8, 1) 
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate()) 
loss_jit(params, batch) 
Array(23.469475, dtype=float32) 
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch)) 
10.760101 
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready() 
5 loops, best of 5: 26.3 ms per loop 
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0]) 
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready() 
5 loops, best of 5: 122 ms per loop 

4 路批数据并行和 2 路模型张量并行

sharding = sharding.reshape(4, 2) 
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1]) 
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘ 
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())

W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))

W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())

W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4) 
jax.debug.visualize_array_sharding(W2) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 
jax.debug.visualize_array_sharding(W3) 
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘ 
print(loss_jit(params, batch)) 
10.760103 
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)] 
print(loss_jit(params, batch)) 
10.752466 
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘ 
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() 
10 loops, best of 10: 30.5 ms per loop 

锐利的部分

生成随机数

JAX 自带一个功能强大且确定性的 随机数生成器。它支持 jax.random 模块 中的各种采样函数,如 jax.random.uniform

JAX 的随机数是由基于计数器的 PRNG 生成的,因此原则上,随机数生成应该是对计数器值的纯映射。原则上,纯映射是一个可以轻松分片的操作。它不应需要跨设备通信,也不应需要设备间的冗余计算。

然而,由于历史原因,现有的稳定 RNG 实现并非自动可分片。

考虑以下示例,其中一个函数绘制随机均匀数并将其逐元素添加到输入中:

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding) 

在分区输入上,函数 f 生成的输出也是分区的:

jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

但是,如果我们检查 f 在这个分区输入上的编译计算,我们会发现它确实涉及一些通信:

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? True 

解决这个问题的一种方法是使用实验性升级标志 jax_threefry_partitionable 配置 JAX。启用该标志后,编译计算中的“集体排列”操作现在已经消失:

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? False 

输出仍然是分区的:

jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

然而,jax_threefry_partitionable 选项的一个注意事项是,即使是由相同随机密钥生成的,使用该标志设置后生成的随机值可能与未设置标志时不同

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x)) 
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ] 

jax_threefry_partitionable 模式下,JAX 的 PRNG 保持确定性,但其实现是新的(并且正在开发中)。为给定密钥生成的随机值在特定的 JAX 版本(或 main 分支上的特定提交)中将保持相同,但在不同版本之间可能会有所变化。

posted @ 2024-06-21 14:07  绝不原创的飞龙  阅读(13)  评论(0编辑  收藏  举报