flax 02 优化器,模型保存
在上一个文章中简要的写了一下jax是怎么用的,但是其中的优化是自己实现的,这一章讲一下优化器的使用方法
flax.optim
和torch相似,与优化相关的方法存在了这个包里面。但是好像在jax中有更高的替代方法,所以这个方法已经被deprecated了。代替的方法是Optax[https://github.com/deepmind/optax] ,为什么不能统一起来啊,真的是。
这个的使用方法在文档中说是非常简单的,那就来一起看一下
- 1、 选择一个优化器,比如
optax.sgd
- 2、从参数中得到优化器状态(可能是设置学习率?)
- 3、使用
jax.value_and_grad()
来计算损失 - 4、在每个迭代上,使用Optax的
update
来对优化器的状态进行更新(类似于troch中的zero_grand?反正就是有一些事物需要处理,不一定是清空梯度),然后使用apply_updates
来进行更新参数
这个需要先计算更新再实施更新的方法有点想django中ORM的变化操作。
需要注意的是,在optax的安装过程中,会出现optax 0.1.2和jax 0.3.10不兼容的情况,这时只需要使用之前同样的安装jax的命令,同时指定jax[cuda]<0.3.7就可以解决这个问题。
基本使用方法
import optax
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
可以看到,这里和torch的部分还是比较像的,先初始化一个优化器,但是需要优化的参数和学习率是分开指定的,在torch中是一次指定的,需要优化的参数使用优化器进行包裹,最后的损失函数需要使用jax.value_and_grad(mse)
进行包裹
for i in range(101):
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print('Loss step {}: '.format(i), loss_val)
这里就是epoch的训练过程了,训练过程的向前传播是在,损失函数中进行的,损失函数的输出是损失值和梯度。接下来将梯度和参数传给优化器,优化器进行更新
模型保存
在训练后的模型参数是肯定需要保存下来的,保存的方法是使用flax中提供的序列化方法,有两种保存的方式,一是保存成二进制数据,二是保存成字典
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
模型读取
在读取模型时同样也需要先初始化模型的结构,因为上面的序列化只会保存模型的参数。这里有个特别需要注意的地方,模型的结构并不是保存在生成的模型中,而是在参数中,还记得之前的哪个初始化参数的方法吗,就是哪个参数,这个参数不仅仅保存参数的数值,还会保存参数的结构。这个参数会被用来当作读取参数时的模板。serialization.from_bytes(params, bytes_output)
,这个语句就让模型从之前保存的数据(bytes_output)中恢复了参数。