Why I don't recommend JAX (JAX vs. PyTorch): why Google's JAX has been slow to gain industry acceptance, and whether JAX will repeat TensorFlow's fate
After 2017, Google's TensorFlow lost its competition with Facebook's PyTorch. To win back its position in the industry, Google shifted its development focus from TensorFlow to a newly built framework: JAX. In a sense Google has given up on TensorFlow, even though some people inside Google continue to maintain and develop it; the company's AI work has moved almost entirely to JAX, so it is fair to say Google is all in on JAX.
Yet in the roughly seven years since JAX came out, it has never really taken off. Even among people working on AI, scientific computing, large-scale computing, or deep learning, a sizeable fraction have never heard of it, much as many people have never heard of Huawei's MindSpore or Baidu's PaddlePaddle. Why has Google's intended TensorFlow replacement ended up as a niche project? This is Google's flagship framework, and Google is a company that can take on practically the entire global IT industry; in funding and engineering strength it is not in the same weight class as Baidu, Alibaba, or Huawei. So how did JAX end up in this situation?
I don't want to offer too many personal opinions on this question directly. Instead, I will walk through the source code of a PPO implementation (a reinforcement learning algorithm) written in JAX that I recently read, and use it to explain my own take.
A PPO project written in JAX; the project page:
https://github.com/kscalelabs/minppo
The training code:
https://github.com/kscalelabs/minppo/blob/master/minppo/train.py
"""Train a model with a specified environment module.""" import logging import os import pickle import sys from typing import Any, Callable, NamedTuple, Sequence import distrax import flax.linen as nn import jax import jax.numpy as jnp import numpy as np import optax from brax.envs import State from flax.core import FrozenDict from flax.linen.initializers import constant, orthogonal from flax.training.train_state import TrainState from minppo.config import Config, load_config_from_cli from minppo.env import HumanoidEnv logger = logging.getLogger(__name__) class Memory(NamedTuple): done: jnp.ndarray action: jnp.ndarray value: jnp.ndarray reward: jnp.ndarray log_prob: jnp.ndarray obs: jnp.ndarray info: Any class RunnerState(NamedTuple): train_state: TrainState env_state: State last_obs: jnp.ndarray rng: jnp.ndarray class UpdateState(NamedTuple): train_state: TrainState mem_batch: "Memory" advantages: jnp.ndarray targets: jnp.ndarray rng: jnp.ndarray class TrainOutput(NamedTuple): runner_state: RunnerState metrics: Any class MLP(nn.Module): features: Sequence[int] use_tanh: bool = True @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: for feat in self.features[:-1]: x = nn.Dense(feat, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) if self.use_tanh: x = nn.tanh(x) else: x = nn.relu(x) return nn.Dense(self.features[-1], kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) class ActorCritic(nn.Module): num_layers: int hidden_size: int action_dim: int use_tanh: bool = True @nn.compact def __call__(self, x: jnp.ndarray) -> tuple[distrax.Distribution, jnp.ndarray]: actor_mean = MLP([self.hidden_size] * self.num_layers + [self.action_dim], use_tanh=self.use_tanh)(x) actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) critic = MLP([self.hidden_size] * self.num_layers + [1], use_tanh=False)(x) return pi, jnp.squeeze(critic, axis=-1) def save_model(params: FrozenDict, filename: str) -> None: os.makedirs(os.path.dirname(filename), exist_ok=True) with open(filename, "wb") as f: pickle.dump(params, f) def make_train(config: Config) -> Callable[[jnp.ndarray], TrainOutput]: num_updates = config.training.total_timesteps // config.training.num_steps // config.training.num_envs minibatch_size = config.training.num_envs * config.training.num_steps // config.training.num_minibatches env = HumanoidEnv(config) def linear_schedule(count: int) -> float: # Linear learning rate annealing frac = 1.0 - (count // (minibatch_size * config.training.update_epochs)) / num_updates return config.training.lr * frac def train(rng: jnp.ndarray) -> TrainOutput: network = ActorCritic( num_layers=config.model.num_layers, hidden_size=config.model.hidden_size, action_dim=env.action_size, use_tanh=config.model.use_tanh, ) rng, _rng = jax.random.split(rng) init_x = jnp.zeros(env.observation_size) network_params = network.init(_rng, init_x) # Set up optimizer with gradient clipping and optional learning rate annealing if config.training.anneal_lr: tx = optax.chain( optax.clip_by_global_norm(config.opt.max_grad_norm), optax.adam(learning_rate=linear_schedule, eps=1e-5), ) else: tx = optax.chain( optax.clip_by_global_norm(config.opt.max_grad_norm), optax.adam(config.opt.lr, eps=1e-5), ) train_state = TrainState.create( apply_fn=network.apply, params=network_params, tx=tx, ) # JIT-compile environment functions for performance @jax.jit def reset_fn(rng: jnp.ndarray) -> State: rngs = jax.random.split(rng, 
config.training.num_envs) return jax.vmap(env.reset)(rngs) @jax.jit def step_fn(states: State, actions: jnp.ndarray, rng: jnp.ndarray) -> State: return jax.vmap(env.step)(states, actions, rng) rng, reset_rng = jax.random.split(rng) env_state = reset_fn(jnp.array(reset_rng)) obs = env_state.obs def _update_step( runner_state: RunnerState, unused: Memory, ) -> tuple[RunnerState, Any]: def _env_step( runner_state: RunnerState, unused: Memory, ) -> tuple[RunnerState, Memory]: train_state, env_state, last_obs, rng = runner_state # Sample actions from the policy and evaluate the value function pi, value = network.apply(train_state.params, last_obs) rng, action_rng = jax.random.split(rng) action = pi.sample(seed=action_rng) log_prob = pi.log_prob(action) # Step the environment rng, step_rng = jax.random.split(rng) step_rngs = jax.random.split(step_rng, config.training.num_envs) env_state: State = step_fn(env_state, action, step_rngs) obs = env_state.obs reward = env_state.reward done = env_state.done info = env_state.metrics # Store experience for later use in PPO updates memory = Memory(done, action, value, reward, log_prob, last_obs, info) runner_state = RunnerState(train_state, env_state, obs, rng) return runner_state, memory # Collect experience for multiple steps runner_state, mem_batch = jax.lax.scan(_env_step, runner_state, None, config.rl.num_env_steps) # Calculate advantages using Generalized Advantage Estimation (GAE) _, last_val = network.apply(runner_state.train_state.params, runner_state.last_obs) last_val = jnp.array(last_val) def _calculate_gae(mem_batch: Memory, last_val: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: def _get_advantages( gae_and_next_value: tuple[jnp.ndarray, jnp.ndarray], memory: Memory ) -> tuple[tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: gae, next_value = gae_and_next_value done, value, reward = memory.done, memory.value, memory.reward # Calculate TD error and GAE delta = reward + config.rl.gamma * next_value * (1 - done) - value gae = delta + config.rl.gamma * config.rl.gae_lambda * (1 - done) * gae return (gae, value), gae # Reverse-order scan to efficiently compute GAE _, advantages = jax.lax.scan( _get_advantages, (jnp.zeros_like(last_val), last_val), mem_batch, reverse=True, unroll=16, ) return advantages, advantages + mem_batch.value advantages, targets = _calculate_gae(mem_batch, last_val) def _update_epoch( update_state: UpdateState, unused: tuple[jnp.ndarray, jnp.ndarray], ) -> tuple[UpdateState, Any]: def _update_minibatch( train_state: TrainState, batch_info: tuple[Memory, jnp.ndarray, jnp.ndarray] ) -> tuple[TrainState, Any]: mem_batch, advantages, targets = batch_info def _loss_fn( params: FrozenDict, mem_batch: Memory, gae: jnp.ndarray, targets: jnp.ndarray ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: # Recompute values to calculate losses pi, value = network.apply(params, mem_batch.obs) log_prob = pi.log_prob(mem_batch.action) # Compute value function loss value_pred_clipped = mem_batch.value + (value - mem_batch.value).clip( -config.rl.clip_eps, config.rl.clip_eps ) value_losses = jnp.square(value - targets) value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() # Compute policy loss using PPO clipped objective ratio = jnp.exp(log_prob - mem_batch.log_prob) gae = (gae - gae.mean()) / (gae.std() + 1e-8) loss_actor1 = ratio * gae loss_actor2 = jnp.clip(ratio, 1.0 - config.rl.clip_eps, 1.0 + config.rl.clip_eps) * gae loss_actor = 
-jnp.minimum(loss_actor1, loss_actor2) loss_actor = loss_actor.mean() entropy = pi.entropy().mean() total_loss = loss_actor + config.rl.vf_coef * value_loss - config.rl.ent_coef * entropy return total_loss, (value_loss, loss_actor, entropy) # Compute gradients and update model grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) total_loss, grads = grad_fn(train_state.params, mem_batch, advantages, targets) train_state = train_state.apply_gradients(grads=grads) return train_state, total_loss train_state, mem_batch, advantages, targets, rng = update_state rng, _rng = jax.random.split(rng) batch_size = minibatch_size * config.training.num_minibatches if batch_size != config.training.num_steps * config.training.num_envs: raise ValueError("`batch_size` must be equal to `num_steps * num_envs`") # Shuffle and organize data into minibatches permutation = jax.random.permutation(_rng, batch_size) batch = (mem_batch, advantages, targets) batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch) shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch) minibatches = jax.tree_util.tree_map( lambda x: jnp.reshape(x, [config.training.num_minibatches, -1] + list(x.shape[1:])), shuffled_batch, ) # Update model for each minibatch train_state, total_loss = jax.lax.scan(_update_minibatch, train_state, minibatches) update_state = UpdateState(train_state, mem_batch, advantages, targets, rng) return update_state, total_loss # Perform multiple epochs of updates on collected data update_state = UpdateState(runner_state.train_state, mem_batch, advantages, targets, runner_state.rng) update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config.training.update_epochs) runner_state = RunnerState( train_state=update_state.train_state, env_state=runner_state.env_state, last_obs=runner_state.last_obs, rng=update_state.rng, ) return runner_state, mem_batch.info rng, _rng = jax.random.split(rng) runner_state = RunnerState(train_state, env_state, obs, _rng) runner_state, metric = jax.lax.scan(_update_step, runner_state, None, num_updates) return TrainOutput(runner_state=runner_state, metrics=metric) return train def main(args: Sequence[str] | None = None) -> None: logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") if args is None: args = sys.argv[1:] config = load_config_from_cli(args) logger.info("Configuration loaded") rng = jax.random.PRNGKey(config.training.seed) logger.info(f"Random seed set to {config.training.seed}") train_jit = jax.jit(make_train(config)) logger.info("Training function compiled with JAX") logger.info("Starting training...") out = train_jit(rng) logger.info("Training completed") logger.info(f"Saving model to {config.training.model_save_path}") save_model(out.runner_state.train_state.params, config.training.model_save_path) logger.info("Model saved successfully") if __name__ == "__main__": main()
As the code shows, JAX really does fuse a NumPy-style API into its core, deep learning libraries are built on top of it, and on that basis a reinforcement learning algorithm such as PPO can be implemented.
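To make that concrete, here is a minimal sketch of my own (not taken from the repository above) showing the NumPy-like jax.numpy API combined with jax.grad and jax.jit:

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function: mean squared error of a linear model.
def mse(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates it; jax.jit compiles it with XLA.
grad_fn = jax.jit(jax.grad(mse))

x = jnp.ones((8, 3))
y = jnp.zeros(8)
w = jnp.array([0.1, -0.2, 0.3])
print(grad_fn(w, x, y))  # gradient of the loss w.r.t. w, shape (3,)
```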
It is also clear that a PPO implementation in JAX can be shorter than one in PyTorch; a reduction of roughly 20% in code volume is about right.
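For comparison, the core PPO clipped objective looks much the same in PyTorch. The sketch below is my own illustration rather than code from any particular repository, and names such as `clip_eps` are assumptions:

```python
import torch

def ppo_actor_loss(log_prob, old_log_prob, gae, clip_eps=0.2):
    # Normalize advantages, as the JAX version above does.
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = ratio * gae
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    # PPO maximizes the pessimistic (minimum) objective, so negate for a loss.
    return -torch.min(unclipped, clipped).mean()
```

The per-step loss is nearly the same length in either framework; most of the savings in the JAX version arguably come from vmap and scan replacing explicit batching and loop code.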
jax.lax.scan can replace Python's for loops, as shown in the sketch below.
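For example (a minimal sketch of my own), a running sum that would normally be written as a Python loop becomes a single scan:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # carry is the running total; the second output is stacked across steps.
    carry = carry + x
    return carry, carry

xs = jnp.arange(1, 6)                      # [1, 2, 3, 4, 5]
total, cumsum = jax.lax.scan(step, 0, xs)  # replaces: for x in xs: ...
print(total)   # 15
print(cumsum)  # [ 1  3  6 10 15 ]
```

Unlike a Python loop, the scanned body is traced once and compiled, which is what lets the whole of train.py above sit inside a single jax.jit.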
Random keys are generated explicitly:
jax.random.split(rng)
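Unlike NumPy or PyTorch, there is no hidden global random state: every random draw consumes a key that you manage yourself. A minimal sketch of my own:

```python
import jax

key = jax.random.PRNGKey(42)                   # explicit seed -> explicit key
key, subkey = jax.random.split(key)            # split before every use
noise = jax.random.normal(subkey, (3,))        # subkey is consumed here
action_key, value_key = jax.random.split(key)  # functions receive keys as arguments
```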
Neural networks under JAX are defined with the Flax framework:
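Here is a minimal Flax module of my own (not from the repo), in the same style as the MLP and ActorCritic classes above; note that parameters live outside the module and must be created with init and passed back in with apply:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMLP(nn.Module):
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = TinyMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 8)))  # build parameters
y = model.apply(params, jnp.ones((4, 8)))                     # forward pass, shape (4, 1)
```

Compared with torch.nn.Module, where parameters are owned by the module instance, this functional split between module and parameters is exactly the kind of extra ceremony the rest of this post complains about.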
So yes, code written in JAX is shorter than the PyTorch equivalent, but not by that much, because Python plus PyTorch already keeps implementation code quite compact. Swapping PyTorch for JAX shaves off perhaps another 20%, and at that point the extra savings hardly matter. The more serious problem is that Google has carried over the same technical approach it used when building TensorFlow: fold the latest high-efficiency coding styles and the latest computation techniques into the framework. That approach drastically reduces JAX's readability and ease of use.
Keep in mind that PyTorch has always evolved slowly and conservatively as a framework. TensorFlow supported native distributed computing from its earliest releases, something PyTorch only caught up on years later. In the same spirit, TensorFlow absorbed technique after technique: multiprocessing, caching, multiple queues, gRPC, ahead-of-time compilation of computation graphs, and so on; high-performance computing concepts from every field were folded into it. In practice these additions did not noticeably improve TensorFlow's performance. Instead they made it bloated, full of redundant and overlapping APIs, and awkward to use. TensorFlow is nominally a Python library, yet using it feels like writing C++, while PyTorch feels like Python itself: it simply works out of the box.
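A rough illustration of that difference in feel (my own sketch; the TensorFlow side uses the old 1.x graph-and-session style that shaped the framework's reputation):

```python
# TensorFlow 1.x style: build a graph first, then run it inside a session.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[None, 3])
y = tf.layers.dense(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})

# PyTorch: just call things, like ordinary Python.
import torch
layer = torch.nn.Linear(3, 1)
out = layer(torch.tensor([[1.0, 2.0, 3.0]]))
```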
In that sense JAX echoes the old TensorFlow: the technology is impressive on paper, but the practical payoff is unclear. There is no big visible performance win, the code gets somewhat shorter, and readability gets far worse. And just as TensorFlow accumulated redundant APIs and overlapping libraries, the frameworks and packages built on top of JAX sprawl in the same way. PyTorch, like Python, tries to offer one implementation per piece of functionality and to keep it as simple as possible, whereas JAX seems determined to offer several implementations of the same thing, each resting on a different technical substrate, and the result is simply hard to use.
Like TensorFlow, JAX has a very strong technical core but is not pleasant to work with: it is hard to use, hard to read, and hard to get started with. Short of Google mandating it internally and pushing it on its partners, large-scale adoption looks unlikely. Why has Python taken market share from C++ and MATLAB in so many areas? Because it is easy to use, easy to learn, and its code is readable. A certain reduction in code size or a certain gain in runtime efficiency only outweighs ease of use when that gain is very large, and neither TensorFlow nor JAX has managed that.
You can get up to speed with PyTorch in about three days; TensorFlow might take three to six months, and JAX perhaps a year. Under those conditions, how are TensorFlow and JAX supposed to beat PyTorch?
Posted on 2024-12-05 12:05 by Angry_Panda