flax 03 创建自己的模型
模型创建
和torch一样,只有继承了nn.Module的才是可以使用的模型,第一步就是导入包import flax.linen as nn
,接下来会举出一个有非线性激活函数的多层感知机作为例子进行学习。
class ExplicitMLP(nn.Module):
features: Sequence[int]
def setup(self):
# we automatically know what to do with lists, dicts of submodules
self.layers = [nn.Dense(feat) for feat in self.features]
# for single submodules, we would just write:
# self.layer1 = nn.Dense(feat1)
def __call__(self, inputs):
x = inputs
for i, lyr in enumerate(self.layers):
x = lyr(x)
if i != len(self.layers) - 1:
x = nn.relu(x)
return x
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
需要注意的是,features: Sequence[int]
这句话是需要提前引入from typing import Sequence
才可使用,这个features保存了定义模型的中间层的形状,记住是中间层,输入层是使用model.init
去自动推理的。这里的setup应该和torch中的__init__有些类似,用来定义模型,但是不知道为何在01中给出的那个例子中没有使用这个方法。这里定义多层的方式比torch要容易理解,torch必须使用自己的列表类型,而jax中可以直接使用最简单的列表,同时在__call__中写明向前传播的方式就完成了模型的定义。
下面放一下原文档:
As we can see, a nn.Module
subclass is made of:
- A collection of data fields (
nn.Module
are Python dataclasses) - here we only have thefeatures
field of typeSequence[int]
. - A
setup()
method that is being called at the end of the__postinit__
where you can register submodules, variables, parameters you will need in your model. - A
__call__
function that returns the output of the model from a given input. - The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one
layers_n
sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit.
由于模型的参数和模型本身不是绑定的,所以即使有call方法,也是不可以使用model(x)
来进行向前传播的,必须使用apply进行向前传播。
模块内联
看到这里我就基本上看懂为什么01中的那个模型没有使用setup来初始化模型了,使用内联可以在向前传播的同时进行模块的声明。这里再吧那个代码放一下:
class TokenLearnerModule(nn.Module):
"""TokenLearner module.
This is the module used for the experiments in the paper.
Attributes:
num_tokens: Number of tokens.
"""
num_tokens: int
use_sum_pooling: bool = True
@nn.compact
def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
if inputs.ndim == 3:
n, hw, c = inputs.shape
h = int(math.sqrt(hw))
inputs = jnp.reshape(inputs, [n, h, h, c])#保证形状是这个样子的
if h * h != hw:
raise ValueError('Only square inputs supported.')
feature_shape = inputs.shape
selected = inputs
selected = nn.LayerNorm()(selected)
for _ in range(3):#这里就是向前传播了
selected = nn.Conv(
self.num_tokens,
kernel_size=(3, 3),
strides=(1, 1),
padding='SAME',
use_bias=False)(selected) # Shape: [bs, h, w, n_token].
selected = nn.gelu(selected)
selected = nn.Conv(
self.num_tokens,
kernel_size=(3, 3),
strides=(1, 1),
padding='SAME',
use_bias=False)(selected) # Shape: [bs, h, w, n_token].
selected = jnp.reshape(
selected, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
]) # Shape: [bs, h*w, n_token].
selected = jnp.transpose(selected, [0, 2, 1]) # Shape: [bs, n_token, h*w].
selected = nn.sigmoid(selected)[..., None] # Shape: [bs, n_token, h*w, 1].
feat = inputs
feat = jnp.reshape(
feat, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
])[:, None, ...] # Shape: [bs, 1, h*w, c].
if self.use_sum_pooling:
inputs = jnp.sum(feat * selected, axis=2)
else:
inputs = jnp.mean(feat * selected, axis=2)
return inputs
也正是因为有这种一起的写法,所以必须吧模型的参数和设计分开来,如果用torch的想法来的话,每一次执行
selected = nn.Conv(
self.num_tokens,
kernel_size=(3, 3),
strides=(1, 1),
padding='SAME',
use_bias=False)(selected)
都会创建一个新的层去进行传播,但是将模型和参数分离的话,这里只负责结构和传播算法,真实的参数是另外保存的,就不会出现这个问题。
这两种方法需要注意的是:
- 在
setup
中,可以命名一些子层并将它们保留以供进一步使用(例如,自动编码器中的编码器/解码器方法)。 - 如果你想拥有多个方法,那么你需要使用
setup
声明模块,因为@nn.compact
注解只允许注解一个方法。
说白了就是,再一次传播中只用一次的话,可以使用@nn.compact
,但是如果在后续层要用到前面层的话,就只能使用setup
,或者是多个层叠的结构,比如定义好的编码器和解码器,就只能使用setup
。 - 他们初始化的方法不同,但是如何不同没有写
使用@nn.compact
的一个最大的好处可以让模型变得异常清楚简练,这里用个卷积网络来举例
class CNN(nn.Module):
"""A simple CNN model."""
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
是不是非常的简洁易懂,在定义的同时进行了向前传播的定义,看到这个例子我真的是当时就喜欢上了这个框架,但是在进行复杂的各种传播时这个是不能那么写的,不知道对于残差的连接应该怎么实现,不知道会不会有例子
模型参数
在之前的写法中,我们使用了框架自带的Dense
来实现全连接,如果没有提供,如何实现自己的全连接呢?接下来就是使用@nn.compact
来实现自己的全连接
class SimpleDense(nn.Module):
features: int
kernel_init: Callable = nn.initializers.lecun_normal()
bias_init: Callable = nn.initializers.zeros
@nn.compact
def __call__(self, inputs):
kernel = self.param('kernel',
self.kernel_init, # Initialization function
(inputs.shape[-1], self.features)) # shape info.
y = lax.dot_general(inputs, kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
bias = self.param('bias', self.bias_init, (self.features,))
y = y + bias
return y
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
点击查看输出
initialized parameters:
FrozenDict({
params: {
kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],
[ 0.05673932, 0.9909285 , -0.63536596],
[ 0.76134115, -0.3250529 , -0.65221626],
[-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),
bias: DeviceArray([0., 0., 0.], dtype=float32),
},
})
output:
[[ 0.5035518 1.8548558 -0.4270195 ]
[ 0.0279097 0.5589246 -0.43061772]
[ 0.3547128 1.5740999 -0.32865518]
[ 0.5264864 1.2928858 0.10089308]]
目测的话应该是调用父类的param方法来生成可训练的参数形状,然后在后面的model.init(key2, x)
中生成参数。但是这个y = lax.dot_general(inputs, kernel,(((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
是真的没有看懂,甚至在原文中都标注# TODO Why not jnp.dot?
其实前面几个参数还好,就是后面传入的几个空的元组是真的没有看懂到底是怎么用的。
这个self.param
需要三个参数(name, init_fn, *init_args)
name
是这个参数的名字init_fn
是个初始化的方法,照着写就行了,这个函数需要两个参数(PRNGKey, *init_args)
init_args
是提供给初始化的参数也就是init_fn
所需要的参数,目测就传入需要的形状即可
这些步骤都是可以在setup中进行的,但是同样的是不会生成参数的,需要使用init
方法
参数和参数容器
这一部分主要介绍使用变量方法声明模型参数之外的变量。这里实现了一个类似于batrchnorm的方法,但是只是举个例子,在实际应用中应该使用flax已经实现的进行而不是自己按着这个写一个。
class BiasAdderWithRunningMean(nn.Module):
decay: float = 0.99
@nn.compact
def __call__(self, x):
# easy pattern to detect if we're initializing via empty variable tree
is_initialized = self.has_variable('batch_stats', 'mean')
ra_mean = self.variable('batch_stats', 'mean',
lambda s: jnp.zeros(s),
x.shape[1:])
mean = ra_mean.value # This will either get the value or trigger init
bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
if is_initialized:
ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
return x - ra_mean.value + bias
key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)