摘要: 这里总体的看一个MNIST的例子用来看看flax是如何工作的 导包 import jax import jax.numpy as jnp # JAX NumPy from flax import linen as nn # The Linen API from flax.training impor 阅读全文
posted @ 2022-05-26 16:20 hoNoSayaka 阅读(267) 评论(0) 推荐(0) 编辑
摘要: 模型创建 和torch一样,只有继承了nn.Module的才是可以使用的模型,第一步就是导入包import flax.linen as nn,接下来会举出一个有非线性激活函数的多层感知机作为例子进行学习。 class ExplicitMLP(nn.Module): features: Sequenc 阅读全文
posted @ 2022-05-26 14:26 hoNoSayaka 阅读(887) 评论(0) 推荐(1) 编辑
摘要: 在上一个文章中简要的写了一下jax是怎么用的,但是其中的优化是自己实现的,这一章讲一下优化器的使用方法 flax.optim 和torch相似,与优化相关的方法存在了这个包里面。但是好像在jax中有更高的替代方法,所以这个方法已经被deprecated了。代替的方法是Optax[https://gi 阅读全文
posted @ 2022-05-26 09:36 hoNoSayaka 阅读(376) 评论(0) 推荐(0) 编辑
摘要: 安装jax jaxlib pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. # Note: wheels only available on linux. pi 阅读全文
posted @ 2022-05-25 19:13 hoNoSayaka 阅读(574) 评论(0) 推荐(0) 编辑
摘要: 讲述了用平常的训练方法,深度比惨大的网络表现应该优于浅层的网络,因为CNN能够提取low/mid/high-level的特征,网络的层数越多,意味着能够提取到不同level的特征越丰富。并且,越深的网络提取的特征越抽象,越具有语义信息。 但是本文的实验下来表明并不是这样,由此引出了‘退化问题’,训练 阅读全文
posted @ 2021-06-15 09:34 hoNoSayaka 阅读(53) 评论(0) 推荐(0) 编辑