flax 01 基本用法

安装 jax 和 jaxlib

# Install JAX and jaxlib with NVIDIA GPU (CUDA) support.
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
# NOTE(review): the unversioned `jax[cuda]` extra and this `-f` find-links URL
# are the 2022-era instructions; newer JAX releases use versioned extras
# (e.g. `jax[cuda12]`) — check the current JAX installation guide before use.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

安装flax

# Install the released Flax package from PyPI.
pip install flax
# Alternative: install the development version directly from GitHub.
pip install --upgrade git+https://github.com/google/flax.git # but this one did not work for me

文档

文档地址:https://flax.readthedocs.io/en/latest/index.html
先看 flax 模型的参数和初始化,参考下面两个模型中的代码。

class TokenLearnerModule(nn.Module):
  """TokenLearner module.

  This is the module used for the experiments in the paper. It predicts
  `num_tokens` spatial weight maps from the input and pools the input features
  under each map, turning an `h x w` feature map into `num_tokens` tokens.

  Attributes:
    num_tokens: Number of tokens to produce.
    use_sum_pooling: If True, the weighted features are summed over the
      spatial positions; if False, they are averaged instead.
  """
  num_tokens: int
  use_sum_pooling: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies learnable tokenization to the 2D inputs.

    Args:
      inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`. A 3D input is
        treated as a flattened square feature map, so `hw` must be a perfect
        square.

    Returns:
      Output of shape `[bs, n_token, c]`.

    Raises:
      ValueError: If `inputs` is 3D and `hw` is not a perfect square.
    """
    if inputs.ndim == 3:
      n, hw, c = inputs.shape
      h = int(math.sqrt(hw))
      # Validate *before* reshaping: for a non-square `hw`, jnp.reshape itself
      # would fail first and this clearer error would never be raised.
      if h * h != hw:
        raise ValueError('Only square inputs supported.')
      # Restore the [bs, h, w, c] layout that the 2D convolutions expect.
      inputs = jnp.reshape(inputs, [n, h, h, c])

    feature_shape = inputs.shape

    selected = inputs
    selected = nn.LayerNorm()(selected)

    # Weight-map branch: three (3x3 conv -> GELU) stages, each producing one
    # channel per token.
    for _ in range(3):
      selected = nn.Conv(
          self.num_tokens,
          kernel_size=(3, 3),
          strides=(1, 1),
          padding='SAME',
          use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

      selected = nn.gelu(selected)

    # Final conv without activation; the sigmoid below turns it into weights.
    selected = nn.Conv(
        self.num_tokens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

    selected = jnp.reshape(
        selected, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
                  ])  # Shape: [bs, h*w, n_token].
    selected = jnp.transpose(selected, [0, 2, 1])  # Shape: [bs, n_token, h*w].
    # Per-token spatial weights in (0, 1); the trailing axis broadcasts over c.
    selected = nn.sigmoid(selected)[..., None]  # Shape: [bs, n_token, h*w, 1].

    # The pooled features are the raw inputs (not the LayerNorm-ed branch).
    feat = inputs
    feat = jnp.reshape(
        feat, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
              ])[:, None, ...]  # Shape: [bs, 1, h*w, c].

    # Weighted pooling over the spatial axis (axis 2):
    # [bs, n_token, h*w, c] -> [bs, n_token, c].
    if self.use_sum_pooling:
      inputs = jnp.sum(feat * selected, axis=2)
    else:
      inputs = jnp.mean(feat * selected, axis=2)

    return inputs
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  """Plain multilayer perceptron.

  Attributes:
    features: Output width of each Dense layer, in order. Every layer except
      the last one is followed by a ReLU activation.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    """Runs `x` through the Dense stack and returns the last layer's output."""
    hidden_widths = self.features[:-1]
    for width in hidden_widths:
      # Hidden layer: affine map followed by ReLU.
      x = nn.Dense(width)(x)
      x = nn.relu(x)
    # Output layer: affine map only, no activation.
    output_width = self.features[-1]
    return nn.Dense(output_width)(x)

# Three Dense layers with output widths 12, 8 and 4 (ReLU after the first two).
model = MLP([12, 8, 4])
# Dummy batch of 32 inputs with 10 features each; it is used both to infer the
# parameter shapes in init and as the input of the forward pass.
batch = jnp.ones((32, 10))
# init returns the variables pytree; the weights live there, not on `model`.
variables = model.init(jax.random.PRNGKey(0), batch)
# Forward pass with explicit variables; presumably shape (32, 4) — TODO confirm.
output = model.apply(variables, batch)

这里包含了几个 jax 的知识点,但我个人对 jax 也不是很熟悉,所以又去查了 jax 的文档:https://jax.readthedocs.io/en/latest/index.html
找到 random 的相关内容。PRNGKey 是 pseudorandom number generator key(伪随机数生成器的键)的缩写。random.PRNGKey(seed) 会根据给定的种子确定性地构造一个键,它本身并不是随机数:这个键是由两个 uint32 数值组成的数组,比如 key = random.PRNGKey(0) 的返回值是 [0, 0]。jax 中没有像 numpy 那样隐式维护全局状态、直接调用就能得到随机数的函数,所有随机数函数都需要显式地传入这样一个 key。例如,用 random.uniform(key) 就可以得到一个服从 [0, 1) 均匀分布的数字。
同样地,所有的随机数函数都需要这样一个 key,但不需要反复调用 random.PRNGKey。可以用 jax.random.split(key, num=2) 把这个随机键(暂且这么叫)拆分成更多的子键,每一个子键都可以像原来的键那样使用。需要的子键数量由 num 参数给出,接收结果的方式和元组解包类似:k1, k2, k3 = jax.random.split(key, num=3)。注意同一个键只应使用一次:用同一个键再次调用同一个随机函数会得到完全相同的结果,因此需要新的随机数时,应先 split 出新的子键。

参数

参数需要显式地进行初始化。对于习惯了 PyTorch 的人来说,通常是在 __init__ 中定义模型的各层、再在 forward 中写前向传播,所以一开始会很难理解这一点。文档中已经写明:Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter.
具体的参数矩阵的形状交给模型自动推断,自己不需要计算。只需要提供一个假输入(dummy input),模型会根据它自动推算出各个参数矩阵的形状。

# Split one seed into two independent keys: one for the dummy input and one
# for parameter initialization.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input; used to infer parameter shapes.
params = model.init(key2, x)  # Initialization call: parameter shapes come from x.
# Map over every leaf of the params pytree (like Python's built-in map) to
# inspect the shapes. `jax.tree_util.tree_map` replaces the deprecated
# `jax.tree_map`, which newer JAX releases no longer provide.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

model.init_with_output 与 model.init 类似,但除了初始化得到的参数外,还会一并返回模型在这个假输入上的输出,即返回 (output, variables)。需要注意的是,模型的参数并不存储在 model 中,而是保存在 init 返回的 params 里。这一点和 torch 非常不同,也导致了后续恢复模型参数时的做法不同,需要特别注意。

前向传播

前向传播也和 torch 有很大的不同:在 Flax 中,前向传播语句是 model.apply(params, x),参数需要显式地作为第一个参数传入,而不是保存在模型对象中。

反向传播

对于样本 \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\),目标是找到最优的参数 \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\),使得线性模型 \(y = Wx + b\) 的输出在最小二乘损失 \(\frac{1}{k}\sum_{i=1}^{k}\frac{1}{2}\|y_i - (Wx_i + b)\|^2\) 下取得最小值。

准备数据

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
# Draw every key needed below from a single split so that no key is consumed
# twice (JAX's PRNG guidance: reusing a key yields identical/correlated draws).
k1, k2, key_sample, key_noise = random.split(key, 4)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree with the same 'params' -> 'bias'/'kernel'
# layout that model.apply expects.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise: y = x @ W + b + 0.1 * N(0, 1).
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

使用 Flax 的前向传播(model.apply)定义损失函数

# Loss function built on the Flax forward pass (model.apply).
def mse(params, x_batched, y_batched):
  """Batch mean of half the squared error of the global `model`.

  Args:
    params: Variables for `model`, as passed to `model.apply`.
    x_batched: Inputs stacked along axis 0.
    y_batched: Targets stacked along axis 0, aligned with `x_batched`.

  Returns:
    The mean over axis 0 of 0.5 * <y - model(x), y - model(x)>.
  """
  def half_sq_err(x, y):
    # Loss for a single (x, y) pair; vmap below maps it over the batch axis.
    residual = y - model.apply(params, x)
    return jnp.inner(residual, residual) / 2.0

  per_sample = jax.vmap(half_sq_err)(x_batched, y_batched)
  return jnp.mean(per_sample, axis=0)

梯度下降

learning_rate = 0.3  # Gradient step size.
# Sanity check: loss of the ground-truth parameters (nonzero because noise was
# added to y_samples).
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# Returns (loss, grads), with grads taken w.r.t. mse's first argument, params.
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  """Take one plain gradient-descent step on a parameter pytree.

  Args:
    params: Pytree of parameter arrays (e.g. the variables from model.init).
    learning_rate: Step size.
    grads: Pytree of gradients with the same structure as `params`.

  Returns:
    A new pytree in which every leaf `p` becomes `p - learning_rate * g`.
  """
  # `jax.tree_map` is a deprecated alias that newer JAX releases no longer
  # provide; `jax.tree_util.tree_map` is the stable spelling.
  return jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)

# Plain gradient descent: 101 steps, logging the loss every 10 steps.
for i in range(101):
  # Perform one gradient update.
  # loss_val is the loss at params *before* this step's update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

完整的可执行代码

import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn


# Model to fit: a single Dense layer with 5 outputs; the 10-dim input size is
# inferred from the dummy input passed to init below.
model = nn.Dense(features=5)
# Split one seed into a key for the dummy input and a key for initialization.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input (used to infer parameter shapes)
params = model.init(key2, x) # Initialization call
# Uncomment to inspect parameter shapes (jax.tree_map is deprecated in newer
# JAX; jax.tree_util.tree_map is the stable spelling):
# jax.tree_util.tree_map(lambda x: x.shape, params)
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
# Draw every key needed below from a single split so that no key is consumed
# twice (JAX's PRNG guidance: reusing a key yields identical/correlated draws).
k1, k2, key_sample, key_noise = random.split(key, 4)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree with the same 'params' -> 'bias'/'kernel'
# layout that model.apply expects.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise: y = x @ W + b + 0.1 * N(0, 1).
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

# Loss function built on the Flax forward pass (model.apply).
def mse(params, x_batched, y_batched):
  """Batch mean of half the squared error of the global `model`.

  Args:
    params: Variables for `model`, as passed to `model.apply`.
    x_batched: Inputs stacked along axis 0.
    y_batched: Targets stacked along axis 0, aligned with `x_batched`.

  Returns:
    The mean over axis 0 of 0.5 * <y - model(x), y - model(x)>.
  """
  def half_sq_err(x, y):
    # Loss for a single (x, y) pair; vmap below maps it over the batch axis.
    residual = y - model.apply(params, x)
    return jnp.inner(residual, residual) / 2.0

  per_sample = jax.vmap(half_sq_err)(x_batched, y_batched)
  return jnp.mean(per_sample, axis=0)

learning_rate = 0.3  # Gradient step size.
# Sanity check: loss of the ground-truth parameters (nonzero because noise was
# added to y_samples).
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
# Returns (loss, grads), with grads taken w.r.t. mse's first argument, params.
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  """Take one plain gradient-descent step on a parameter pytree.

  Args:
    params: Pytree of parameter arrays (e.g. the variables from model.init).
    learning_rate: Step size.
    grads: Pytree of gradients with the same structure as `params`.

  Returns:
    A new pytree in which every leaf `p` becomes `p - learning_rate * g`.
  """
  # `jax.tree_map` is a deprecated alias that newer JAX releases no longer
  # provide; `jax.tree_util.tree_map` is the stable spelling.
  return jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)

# Plain gradient descent: 101 steps, logging the loss every 10 steps.
for i in range(101):
  # Perform one gradient update.
  # loss_val is the loss at params *before* this step's update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
posted @ 2022-05-25 19:13  hoNoSayaka  阅读(560)  评论(0编辑  收藏  举报