谷歌对深度学习感到兴奋的 Jax 是什么?
谷歌对深度学习感到兴奋的 Jax 是什么?
大概 你听说过 TensorFlow 和 PyTorch 吗? 这是另一个较新的机器学习框架: 贾克斯。
那么 JAX 有什么令人兴奋的地方呢?
什么是 JAX?
对于深度神经网络,最小化每个参数中的损失函数并执行梯度下降可能会非常消耗资源。传统方法涉及手动推导和编码,或使用 TensorFlow 等机器学习框架的句法和语义约束应用神经模型。
如果可以使用 NumPy 库简单地编写损失函数并自动完成工作会怎样?这正是 Jax 所做的。
JAX 是一个类似于 NumPy 的库,但具有一些更强大的功能:
JAX 使用 XLA 在 GPU 和 TPU 等加速器上编译和运行 NumPy 代码。编译默认在后台进行,实时 (JIT) 编译和库调用执行。
JAX 允许您使用单线程 API 在 XLA 优化内核上即时编译您自己的 Python 函数。编译和自动微分可以按需构建,因此您可以表达复杂的算法并获得最佳性能,而无需离开 Python。
用于机器学习的 JAX 和 TensorFlow
JAX 和 TensorFlow 由 Google 编写。开发 JAX 似乎更容易一些,可以说更直观。
然而,JAX 缺乏 TensorFlow 多年来构建的广泛基础设施,无论是开源项目、预训练模型、教程、更高级别的抽象(通过 Keras)以及对部署目标的可移植性。
用于机器学习的 JAX 和 PyTorch
另一个接近 JAX 的机器学习框架是 PyTorch。
JAX 具有较低级别函数定义的功能使其更适合某些研究任务。
但是,PyTorch 提供了更多的库和实用程序、预训练和预编写的网络定义、数据加载器以及对部署目标的可移植性。
如果您在研究领域工作,JAX 是您项目的不错选择。
如果您正在积极开发应用程序,那么使用 PyTorch 或 TensorFlow 框架将使您的初创公司发展得更快。
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明