22、自定义网络、模型保存与加载
- keras.Sequential 容器
- keras.layers.Layers
- keras.Model
1、keras.Sequential 容器
(1)如果不使用容器,就需要写很多行代码,而且需要关注每一个layers的参数,容器很方便的就能创建一个网络
1 model = Sequential([ 2 layers.Dense(256, activation=tf.nn.relu), # 256是该层的输入,[b, 784] => [b, 256] 3 layers.Dense(128, activation=tf.nn.relu), #[b, 256] => [b, 128] 4 layers.Dense(64, activation=tf.nn.relu), #[b, 128] => [b, 64] 5 layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32] 6 layers.Dense(10)]) # [b, 32] => [b, 10], 330 = 32*10 + 10 最后一层需要激活函数 7 8 # dense表示全连接层 9 model.build(input_shape=[None,28*28]) #对第一层进行输入,权值的构建需要输入的维度,build会自动创建w和b参数 10 model.summary() #调试的功能,打印网络结构
(2)model.trainable_variables 和 model.call()
得到一个列表[w1,b1,w2,b2,......wn,bn],每一个i是一个dense层,通过Sequential容器将原来分散在各个layer层的参数集中起来,在自定义的时候,每做一次forward的时候,我们需要将输入的x传给dense1—>dense2—>dense3......dense5,而这样的一个重复的逻辑可以直接调用一次Sequential就能实现,它会在内部调用call方法,一般来讲,我们调用一个network的forward方法是通过python中的model__call__方法实现,在自定义网络中,需要自己实现call方法,而在call方法中,我们需要继承__call__方法。
2、keras.layers.Layer / keras.Model
- 继承 keras.layers.Layer 和 keras.Model, 这两个类是所有自定义层的母类
- _init_ , 初始化函数,在初始化函数里面调用一个父类的初始化方法,
- call, 自己的逻辑在call方法中,对一个类的调用,直接使用这个类的model(x),这个model(x)会调用model.__call__(x),model.__call__(x)会调用母类的call方法
- Model: compile/fit/evaluate, 在model这个类中还有其他的接口
(1)创建Dense
class MyDense(layers.Layer): # to replace standard layers.Dense() def __init__(self, inp_dim, outp_dim): super(MyDense, self).__init__() # 创建时不能创建一个常量,即constant self.kernel = self.add_variable('w', [inp_dim, outp_dim]) #可自定义的点 self.bias = self.add_variable('b', [outp_dim]) #设置w和b的维度 def call(self, inputs, training=None): x = inputs @ self.kernel return x
(2)创建网络
我们可以在call中进行一些额外的操作,如对x进行加1等
1 class MyModel(keras.Model): 2 3 def __init__(self): 4 super(MyModel, self).__init__() 5 6 self.fc1 = MyDense(28*28, 256) 7 self.fc2 = MyDense(256, 128) 8 self.fc3 = MyDense(128, 64) 9 self.fc4 = MyDense(64, 32) 10 self.fc5 = MyDense(32, 10) 11 12 def call(self, inputs, training=None): 13 14 x = self.fc1(inputs) 15 x = tf.nn.relu(x) 16 x = self.fc2(x) 17 x = tf.nn.relu(x) 18 x = self.fc3(x) 19 x = tf.nn.relu(x) 20 x = self.fc4(x) 21 x = tf.nn.relu(x) 22 x = self.fc5(x) 23 return x
3、模型的保存和加载
- save/load weights 最轻量级的保存方式,只保存模型的权值参数
- save/load entire model 将模型的所有状态进行保存
- saved_model 模型的一种保存格式,如果训练出一个模型之后,要将模型交给生产环境的时候,只需将模型交给用户进行部署,而不需要将模型的源代码交给用户
在深度学习的模型中,一个网络可能会非常的复杂,训练时间较长,因此需要将模型进行保存,在需要的的时候进行加载,使用python进行原型开发,导出模型之后,使用c++完成工业环境的部署