Why I don't recommend JAX (JAX vs PyTorch): why JAX has struggled to win industry acceptance since Google released it, and whether JAX will repeat TensorFlow's fate

After 2017, Google's TensorFlow lost out in its competition with Facebook's PyTorch. To win back its position in the industry, Google shifted its development focus from TensorFlow to a newly built framework: JAX. In a sense Google has already given up on TensorFlow, although a portion of its engineers still maintain and develop it; almost all of Google's AI work, however, has moved to JAX, so it is fair to say Google has gone all in on JAX.


Yet JAX has been out for nearly seven years now, and in all that time it has remained lukewarm at best. Even among people doing AI, scientific computing, large-scale computing, or deep learning, a sizeable fraction have probably never heard of Google's JAX, much as many people have never heard of Huawei's MindSpore or Baidu's PaddlePaddle. So why has JAX, Google's intended replacement for TensorFlow, also ended up as a project known only to a small audience? Keep in mind that this is Google, and JAX is one of its flagship projects. Google is a company that can take on nearly the entire global IT industry; in both funding and engineering strength it is not in the same weight class as Baidu, Alibaba, or Huawei. Why, then, has Google's JAX landed in such a position?


I don't want to push too many personal opinions on this question, so below I will walk through the source code of a PPO algorithm (a reinforcement learning method) written in JAX that I recently read, and use it to explain my own take.


Here is a PPO project written in JAX. Project address:

https://github.com/kscalelabs/minppo





Code:

https://github.com/kscalelabs/minppo/blob/master/minppo/train.py


"""Train a model with a specified environment module."""
import logging
import os
import pickle
import sys
from typing import Any, Callable, NamedTuple, Sequence
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from brax.envs import State
from flax.core import FrozenDict
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from minppo.config import Config, load_config_from_cli
from minppo.env import HumanoidEnv
logger = logging.getLogger(__name__)
class Memory(NamedTuple):
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: Any
class RunnerState(NamedTuple):
train_state: TrainState
env_state: State
last_obs: jnp.ndarray
rng: jnp.ndarray
class UpdateState(NamedTuple):
train_state: TrainState
mem_batch: "Memory"
advantages: jnp.ndarray
targets: jnp.ndarray
rng: jnp.ndarray
class TrainOutput(NamedTuple):
runner_state: RunnerState
metrics: Any
class MLP(nn.Module):
features: Sequence[int]
use_tanh: bool = True
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
for feat in self.features[:-1]:
x = nn.Dense(feat, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
if self.use_tanh:
x = nn.tanh(x)
else:
x = nn.relu(x)
return nn.Dense(self.features[-1], kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)
class ActorCritic(nn.Module):
num_layers: int
hidden_size: int
action_dim: int
use_tanh: bool = True
@nn.compact
def __call__(self, x: jnp.ndarray) -> tuple[distrax.Distribution, jnp.ndarray]:
actor_mean = MLP([self.hidden_size] * self.num_layers + [self.action_dim], use_tanh=self.use_tanh)(x)
actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
critic = MLP([self.hidden_size] * self.num_layers + [1], use_tanh=False)(x)
return pi, jnp.squeeze(critic, axis=-1)
def save_model(params: FrozenDict, filename: str) -> None:
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
pickle.dump(params, f)
def make_train(config: Config) -> Callable[[jnp.ndarray], TrainOutput]:
num_updates = config.training.total_timesteps // config.training.num_steps // config.training.num_envs
minibatch_size = config.training.num_envs * config.training.num_steps // config.training.num_minibatches
env = HumanoidEnv(config)
def linear_schedule(count: int) -> float:
# Linear learning rate annealing
frac = 1.0 - (count // (minibatch_size * config.training.update_epochs)) / num_updates
return config.training.lr * frac
def train(rng: jnp.ndarray) -> TrainOutput:
network = ActorCritic(
num_layers=config.model.num_layers,
hidden_size=config.model.hidden_size,
action_dim=env.action_size,
use_tanh=config.model.use_tanh,
)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_size)
network_params = network.init(_rng, init_x)
# Set up optimizer with gradient clipping and optional learning rate annealing
if config.training.anneal_lr:
tx = optax.chain(
optax.clip_by_global_norm(config.opt.max_grad_norm),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config.opt.max_grad_norm),
optax.adam(config.opt.lr, eps=1e-5),
)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)
# JIT-compile environment functions for performance
@jax.jit
def reset_fn(rng: jnp.ndarray) -> State:
rngs = jax.random.split(rng, config.training.num_envs)
return jax.vmap(env.reset)(rngs)
@jax.jit
def step_fn(states: State, actions: jnp.ndarray, rng: jnp.ndarray) -> State:
return jax.vmap(env.step)(states, actions, rng)
rng, reset_rng = jax.random.split(rng)
env_state = reset_fn(jnp.array(reset_rng))
obs = env_state.obs
def _update_step(
runner_state: RunnerState,
unused: Memory,
) -> tuple[RunnerState, Any]:
def _env_step(
runner_state: RunnerState,
unused: Memory,
) -> tuple[RunnerState, Memory]:
train_state, env_state, last_obs, rng = runner_state
# Sample actions from the policy and evaluate the value function
pi, value = network.apply(train_state.params, last_obs)
rng, action_rng = jax.random.split(rng)
action = pi.sample(seed=action_rng)
log_prob = pi.log_prob(action)
# Step the environment
rng, step_rng = jax.random.split(rng)
step_rngs = jax.random.split(step_rng, config.training.num_envs)
env_state: State = step_fn(env_state, action, step_rngs)
obs = env_state.obs
reward = env_state.reward
done = env_state.done
info = env_state.metrics
# Store experience for later use in PPO updates
memory = Memory(done, action, value, reward, log_prob, last_obs, info)
runner_state = RunnerState(train_state, env_state, obs, rng)
return runner_state, memory
# Collect experience for multiple steps
runner_state, mem_batch = jax.lax.scan(_env_step, runner_state, None, config.rl.num_env_steps)
# Calculate advantages using Generalized Advantage Estimation (GAE)
_, last_val = network.apply(runner_state.train_state.params, runner_state.last_obs)
last_val = jnp.array(last_val)
def _calculate_gae(mem_batch: Memory, last_val: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
def _get_advantages(
gae_and_next_value: tuple[jnp.ndarray, jnp.ndarray], memory: Memory
) -> tuple[tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
gae, next_value = gae_and_next_value
done, value, reward = memory.done, memory.value, memory.reward
# Calculate TD error and GAE
delta = reward + config.rl.gamma * next_value * (1 - done) - value
gae = delta + config.rl.gamma * config.rl.gae_lambda * (1 - done) * gae
return (gae, value), gae
# Reverse-order scan to efficiently compute GAE
_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
mem_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + mem_batch.value
advantages, targets = _calculate_gae(mem_batch, last_val)
def _update_epoch(
update_state: UpdateState,
unused: tuple[jnp.ndarray, jnp.ndarray],
) -> tuple[UpdateState, Any]:
def _update_minibatch(
train_state: TrainState, batch_info: tuple[Memory, jnp.ndarray, jnp.ndarray]
) -> tuple[TrainState, Any]:
mem_batch, advantages, targets = batch_info
def _loss_fn(
params: FrozenDict, mem_batch: Memory, gae: jnp.ndarray, targets: jnp.ndarray
) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
# Recompute values to calculate losses
pi, value = network.apply(params, mem_batch.obs)
log_prob = pi.log_prob(mem_batch.action)
# Compute value function loss
value_pred_clipped = mem_batch.value + (value - mem_batch.value).clip(
-config.rl.clip_eps, config.rl.clip_eps
)
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
# Compute policy loss using PPO clipped objective
ratio = jnp.exp(log_prob - mem_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = jnp.clip(ratio, 1.0 - config.rl.clip_eps, 1.0 + config.rl.clip_eps) * gae
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()
total_loss = loss_actor + config.rl.vf_coef * value_loss - config.rl.ent_coef * entropy
return total_loss, (value_loss, loss_actor, entropy)
# Compute gradients and update model
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(train_state.params, mem_batch, advantages, targets)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss
train_state, mem_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)
batch_size = minibatch_size * config.training.num_minibatches
if batch_size != config.training.num_steps * config.training.num_envs:
raise ValueError("`batch_size` must be equal to `num_steps * num_envs`")
# Shuffle and organize data into minibatches
permutation = jax.random.permutation(_rng, batch_size)
batch = (mem_batch, advantages, targets)
batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(x, [config.training.num_minibatches, -1] + list(x.shape[1:])),
shuffled_batch,
)
# Update model for each minibatch
train_state, total_loss = jax.lax.scan(_update_minibatch, train_state, minibatches)
update_state = UpdateState(train_state, mem_batch, advantages, targets, rng)
return update_state, total_loss
# Perform multiple epochs of updates on collected data
update_state = UpdateState(runner_state.train_state, mem_batch, advantages, targets, runner_state.rng)
update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config.training.update_epochs)
runner_state = RunnerState(
train_state=update_state.train_state,
env_state=runner_state.env_state,
last_obs=runner_state.last_obs,
rng=update_state.rng,
)
return runner_state, mem_batch.info
rng, _rng = jax.random.split(rng)
runner_state = RunnerState(train_state, env_state, obs, _rng)
runner_state, metric = jax.lax.scan(_update_step, runner_state, None, num_updates)
return TrainOutput(runner_state=runner_state, metrics=metric)
return train
def main(args: Sequence[str] | None = None) -> None:
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
if args is None:
args = sys.argv[1:]
config = load_config_from_cli(args)
logger.info("Configuration loaded")
rng = jax.random.PRNGKey(config.training.seed)
logger.info(f"Random seed set to {config.training.seed}")
train_jit = jax.jit(make_train(config))
logger.info("Training function compiled with JAX")
logger.info("Starting training...")
out = train_jit(rng)
logger.info("Training completed")
logger.info(f"Saving model to {config.training.model_save_path}")
save_model(out.runner_state.train_state.params, config.training.model_save_path)
logger.info("Model saved successfully")
if __name__ == "__main__":
main()

As the listing shows, JAX really does fuse with NumPy: it exposes a NumPy-style array API, deep learning libraries (here Flax and Optax) are built on top of it, and on that foundation a reinforcement learning algorithm, PPO, has been implemented.
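
A minimal illustration of that "NumPy fusion" (my own toy example, not from the minppo project): jax.numpy mirrors the NumPy API, and jax.grad / jax.jit turn ordinary array code into differentiable, compiled code, which is what libraries such as Flax and Optax build on.

import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Ordinary NumPy-style array math, written against jax.numpy.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.jit(jax.grad(loss))   # compiled gradient of the loss w.r.t. w
w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
w = w - 0.1 * grad_loss(w, x, y)      # one step of plain gradient descent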

It is also clear that the JAX implementation of PPO can be shorter than a PyTorch one; a reduction of roughly 20% in code volume seems about right.


jax.lax.scan can take the place of Python for loops.
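
As a rough illustration (my own toy example, not from the project), here is how a plain Python accumulation loop maps onto jax.lax.scan: the carry plays the role of the loop variable, and the stacked per-step outputs replace an appended list.

import jax
import jax.numpy as jnp

xs = jnp.arange(5.0)

# Plain Python loop: carry a running total and record it at every step.
total, history = 0.0, []
for x in xs:
    total = total + x
    history.append(total)

# Equivalent jax.lax.scan: the step function returns (new_carry, per_step_output),
# and scan stacks the per-step outputs into a single array.
def step(carry, x):
    carry = carry + x
    return carry, carry

total_scan, history_scan = jax.lax.scan(step, 0.0, xs)
# history_scan == [0., 1., 3., 6., 10.] and total_scan == 10.0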


Random keys are generated explicitly:

jax.random.split(rng)
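
A minimal sketch of this explicit key discipline (my own example, not from the project): every random call consumes its own key, which you must derive by splitting; there is no hidden global RNG state of the kind PyTorch relies on.

import jax

rng = jax.random.PRNGKey(0)                             # root key derived from a seed
rng, action_rng, noise_rng = jax.random.split(rng, 3)   # split before each use

action = jax.random.normal(action_rng, (4,))            # each consumer gets its own key
noise = jax.random.uniform(noise_rng, (4,))
# Reusing a key reproduces the same numbers, so forgetting to split is a common bug.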


This is how networks are defined in Flax, the neural network framework built on top of JAX:

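For reference, a condensed sketch in the same style as the MLP class in the listing above (my own example standing in for the original screenshot; TinyMLP and its fields are hypothetical, not part of minppo): a network is an nn.Module whose __call__ is decorated with @nn.compact, and its parameters are created by init and passed explicitly to apply rather than living inside the module.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMLP(nn.Module):       # toy module for illustration only
    hidden_size: int
    out_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nn.tanh(nn.Dense(self.hidden_size)(x))
        return nn.Dense(self.out_dim)(x)

model = TinyMLP(hidden_size=64, out_dim=3)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))  # explicit init with a key
y = model.apply(params, jnp.ones((1, 8)))                      # parameters passed explicitly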


To be fair, the JAX code really is shorter than the equivalent PyTorch code, but not by much. The Python + PyTorch combination already keeps algorithm implementations quite compact; switching to JAX can shave off perhaps another 20%, and at that point the saving hardly matters. The more serious problem is that Google has carried over the same technical mindset it used when building TensorFlow: fold the latest "efficient" coding styles and computation models directly into the framework. In JAX this shows up as the functional style the listing above is forced into: parameters are threaded explicitly through network.apply, randomness requires explicit key splitting, and loops must be rewritten as jax.lax.scan so they can be traced and compiled. That approach greatly reduces JAX's readability and ease of use.

Keep in mind that PyTorch is a framework that has evolved very conservatively. TensorFlow supported native distributed computation from its earliest releases, while native distributed support arrived in PyTorch years later. In the same spirit, TensorFlow absorbed technology after technology: multiprocessing, caching, multiple queues, gRPC, ahead-of-time compilation of the computation graph, and so on; high-performance-computing concepts from all sorts of domains were folded into it. In practice, these additions did not make TensorFlow faster. They made it bloated, full of redundant and duplicated APIs, and awkward to use. TensorFlow is nominally a Python library, yet using it feels like writing C++, whereas PyTorch feels like Python itself, essentially working out of the box.

JAX is in many ways a repeat of the TensorFlow story: the underlying technology is impressive, but the practical payoff is small. The performance gains are hard to see, the code gets somewhat shorter, yet readability becomes far worse. And just as TensorFlow accumulated redundant APIs and overlapping libraries, the ecosystem built on JAX suffers from the same sprawl; for neural networks alone there are several competing libraries (Flax among them). PyTorch, like Python, tries to offer one obvious way to do each thing and to keep it as simple as possible, whereas the JAX world seems determined to provide multiple implementations of everything, each resting on a different technical layer, and the result is a framework that is simply unpleasant to use.


Like TensorFlow, JAX has a very strong technical core, but it is not pleasant to use, not easy to learn, produces code that is hard to read, and is hard to get started with. Apart from mandatory use inside Google and among Google's partners, it is difficult to see it being adopted at scale. Why has Python displaced C++ and MATLAB in so many areas? Because it is easy to use, easy to learn, and easy to read. For a modest reduction in code size or a modest gain in runtime efficiency to outweigh ease of use, the gain has to be very large indeed, and neither TensorFlow nor JAX clears that bar.

You can get productive with PyTorch in about three days; TensorFlow may take three to six months, and JAX perhaps a year. Under those conditions, how are TensorFlow and JAX supposed to beat PyTorch?


