摘要: 这里总体的看一个MNIST的例子用来看看flax是如何工作的 导包 import jax import jax.numpy as jnp # JAX NumPy from flax import linen as nn # The Linen API from flax.training impor 阅读全文
posted @ 2022-05-26 16:20 hoNoSayaka 阅读(228) 评论(0) 推荐(0) 编辑
摘要: 模型创建 和torch一样,只有继承了nn.Module的才是可以使用的模型,第一步就是导入包import flax.linen as nn,接下来会举出一个有非线性激活函数的多层感知机作为例子进行学习。 class ExplicitMLP(nn.Module): features: Sequenc 阅读全文
posted @ 2022-05-26 14:26 hoNoSayaka 阅读(786) 评论(0) 推荐(0) 编辑
摘要: 在上一个文章中简要的写了一下jax是怎么用的,但是其中的优化是自己实现的,这一章讲一下优化器的使用方法 flax.optim 和torch相似,与优化相关的方法存在了这个包里面。但是好像在jax中有更高的替代方法,所以这个方法已经被deprecated了。代替的方法是Optax[https://gi 阅读全文
posted @ 2022-05-26 09:36 hoNoSayaka 阅读(321) 评论(0) 推荐(0) 编辑