(二)tensorflow2.0 - 自定义Model
前文写了如何使用tensorflow2.0自定义Layer,本文将讲述如何自定义Model,并将前述的Layer应用到本Model中来。
(一)tensorflow2.0 - 自定义layer
(二)tensorflow2.0 - 自定义Model
(三)tensorflow2.0 - 自定义loss function(损失函数)
(四)tensorflow2.0 - 实战稀疏自动编码器SAE
自定义模型也比较简单,只是需要搞清楚Model中各部分的作用及执行流程即可。
由于本例中将使用前文中的自定义Layer,因此先将其代码贴过来以便查阅,没看过前文的也没关系,不影响对自定义模型的理解。
import tensorflow as tf
from tensorflow.keras import *
class SAELayer(layers.Layer):
# 初始化num_outputs,即当前层输出元素的个数
def __init__(self, num_outputs):
super(SAELayer, self).__init__()
self.num_outputs = num_outputs
# 在第一次调用该Layer的call方法前(自动)调用该函数,可以知道输入数据的shape
# 根据输入数据的shape可以初始化权值、bias的矩阵
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
self.bias = self.add_variable("bias",
shape=[self.num_outputs])
def call(self, input):
output = tf.matmul(input, self.kernel) + self.bias
# sigmoid激活函数
output = tf.nn.sigmoid(output)
return output
下面自定义模型了,引入的库函数见上面代码的最前面。需要注意,Layer和Model都是类,且都要继承自某些父类。这里继承的是tensorflow.keras.Model
。这里需要实现两个方法,即__init__()
和call()
。__init__()
是在创建类的对象时调用的,可以按需传入一些初始化参数。下例构建的是一个三层模型(输入层由于体现不出来,所以代码里看起来是两层)。
class SAEModel(Model):
def __init__(self, input_shape, output_shape, hidden_shape=None):
# print("init")
# 隐藏层节点个数默认为输入层的3倍
if hidden_shape == None:
hidden_shape = 3 * input_shape
# 调用父类__init__()方法
super(SAEModel, self).__init__()
# 初始化模型使用的layer,layer_1为前述自定义layer
self.layer_1 = SAELayer(hidden_shape)
# layer_2为全连接层,采用sigmoid激活函数
# 每层在这里可以不考虑输入元素个数,但必须考虑输出元素个数
# 输入元素个数可以在call()函数中动态确定
self.layer_2 = layers.Dense(output_shape, activation=tf.nn.sigmoid)
def call(self, input_tensor, training=False):
# 输入数据
hidden = self.layer_1(input_tensor)
output = self.layer_2(hidden)
return output
到此模型就定义完了,然后可以按照一般的流程使用该模型。
下面只是简单的使用模型的例子,只罗列出来,参数没有完善,请按需补充后使用。
input_shape = 5
output_shape = 6
model = SAEModel(input_shape, output_shape)
model.build(input_shape=[None, 5])
model.summary()
model.compile(optimizer=, loss=, metrics=[])
到此自定义Model已经结束了,但是很多时候我们往往需要自定义损失函数,而如果损失函数需要自定义除了预测值和实际值之外的额外参数的话,还需要对model进行修改,这我们将在下一篇文章中讨论。
(三)tensorflow2.0 - 自定义loss function(损失函数)
参考文献: