JAX介绍和快速入门示例

JAX 是一个由 Google 开发的用于优化科学计算Python 库:

  • 它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。
  • 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。
  • 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。
  • 它对自动微分有很好的支持,对机器学习研究很有用。可以使用 jax.grad() 触发自动区分。
  • JAX 鼓励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。
  • JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT . JAX()用于JIT编译和加速代码,JIT .grad()用于求导,以及JIT .vmap()用于自动向量化或批处理。
  • JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。

JAX 使用 JIT 编译有两种方式:

  • 自动:在执行 JAX 函数的库调用时,默认情况下 JIT 编译会在后台进行。
  • 手动:您可以使用 jax.jit() 手动请求对自己的 Python 函数进行 JIT 编译。

JAX 使用示例

我们可以使用 pip 安装库。

  1. pip install jax

导入需要的包,这里我们也继续使用 NumPy ,这样可以执行一些基准测试。

  1. import jax
  2. import jax.numpy as jnp
  3. from jax import random
  4. from jax import grad, jit
  5. import numpy as np
  6. key = random.PRNGKey(0)

与 import numpy as np 类似,我们可以 import jax.numpy as jnp 并将代码中的所有 np 替换为 jnp 。如果 NumPy 代码是用函数式编程风格编写的,那么新的 JAX 代码就可以直接使用。但是,如果有可用的GPU,JAX则可以直接使用。

JAX 中随机数的生成方式与 NumPy 不同。JAX需要创建一个 jax.random.PRNGKey 。我们稍后会看到如何使用它。

我们在 Google Colab 上做一个简单的基准测试,这样我们就可以轻松访问 GPU 和 TPU。我们首先初始化一个包含 25M 元素的随机矩阵,然后将其乘以它的转置。使用针对 CPU 优化的 NumPy,矩阵乘法平均需要 1.61 秒。

  1. # runs on CPU - numpy
  2. size = 5000
  3. x = np.random.normal(size=(size, size)).astype(np.float32)
  4. %timeit np.dot(x, x.T)
  5. # 1 loop, best of 5: 1.61 s per loop

在 CPU 上使用 JAX 执行相同的操作平均需要大约 3.49 秒。

  1. # runs on CPU - JAX
  2. size = 5000
  3. x = random.normal(key, (size, size), dtype=jnp.float32)
  4. %timeit jnp.dot(x, x.T).block_until_ready()
  5. # 1 loop, best of 5: 3.49 s per loop

在 CPU 上运行时,JAX 通常比 NumPy 慢,因为 NumPy 已针对CPU进行了非常多的优化。但是,当使用加速器时这种情况会发生变化,所以让我们尝试使用 GPU 进行矩阵乘法。

完整文章

https://avoid.overfit.cn/post/589106b6f0a0431480a42a1bf399e81e

posted @ 2022-06-06 11:16  deephub  阅读(1496)  评论(0编辑  收藏  举报