Loading

3-3高阶API示范——eat_tensorflow2_in_30_days

3-3高阶API示范

TensorFlow的高阶API主要是 tf.keras.models 提供的模型类接口。

使用Keras接口构建模型有以下3种方式：使用Sequential按层顺序构建模型；使用函数式API构建任意结构的模型；继承Model基类构建自定义模型。本文演示其中的第1种和第3种方式。

使用Sequential按层顺序构建模型

import tensorflow as tf
from tensorflow.keras import models, layers, optimizers

# Number of synthetic samples.
n = 800

# Synthetic regression data: uniform features in [-10, 10), a fixed linear
# target (weights w0, bias b0) and additive Gaussian noise with stddev 2.
X = tf.random.uniform([n, 2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)
Y = tf.matmul(X, w0) + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)

tf.keras.backend.clear_session()

# One Dense unit over 2 input features, declared via the list-form constructor.
linear = models.Sequential([layers.Dense(1, input_shape=(2,))])
linear.summary()

"""
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1)                 3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
"""
# Built-in training loop: Adam optimizer, MSE loss, MAE reported as a metric.
linear.compile(optimizer='adam', loss='mse', metrics=['mae'])
linear.fit(X, Y, batch_size=20, epochs=200)

# Learned parameters of the single Dense layer (should move toward w0 / b0).
tf.print("w=", linear.layers[0].kernel)
tf.print("b=", linear.layers[0].bias)

继承Model基类构建自定义模型

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers,losses,metrics


# Print a time separator line.
@tf.function
def printbar(utc_offset_hours=8):
    """Print a separator bar followed by the current time as HH:MM:SS.

    The time comes from ``tf.timestamp()`` (seconds since the Unix epoch).
    The displayed hour is shifted by ``utc_offset_hours``. The default of 8
    reproduces the original hard-coded UTC+8 display.

    Args:
        utc_offset_hours: Whole-hour offset from UTC applied to the hour field.
            Pass a Python int; each distinct value causes a separate trace.
    """
    ts = tf.timestamp()
    # Seconds elapsed since midnight UTC.
    today_ts = ts % (24 * 60 * 60)

    # `%` on tensors is a floor-mod, so negative offsets also wrap into 0..23.
    hour = tf.cast(today_ts // 3600 + utc_offset_hours, tf.int32) % tf.constant(24)
    minute = tf.cast((today_ts % 3600) // 60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    # Zero-pad each field to two digits (7 -> "07") with the built-in op
    # instead of a hand-written length check.
    timestring = tf.strings.join(
        [tf.strings.as_string(field, width=2, fill="0") for field in (hour, minute, second)],
        separator=":",
    )
    tf.print("==========" * 8, end="")
    tf.print(timestring)
# Number of samples.
n = 800

# Synthetic regression data: Y = X @ w0 + b0 plus Gaussian noise (stddev 2).
X = tf.random.uniform([n,2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)  # @ is matrix multiplication; Gaussian noise added

# First 3/4 of the samples train, the remaining 1/4 validate.
# Pipeline order matters:
#   * cache() comes right after the source, so the raw samples are cached;
#   * shuffle() runs AFTER cache(), so the order is re-randomized every
#     epoch (caching after shuffle would replay the first epoch's order);
#   * prefetch() is last, so it overlaps the whole pipeline with training.
ds_train = tf.data.Dataset.from_tensor_slices((X[0:n*3//4, :], Y[0:n*3//4, :])) \
            .cache() \
            .shuffle(buffer_size=1000) \
            .batch(20) \
            .prefetch(tf.data.experimental.AUTOTUNE)

# The validation split is only evaluated, so it is not shuffled.
ds_valid = tf.data.Dataset.from_tensor_slices((X[n*3//4:, :], Y[n*3//4:, :])) \
            .cache() \
            .batch(20) \
            .prefetch(tf.data.experimental.AUTOTUNE)

# Reset Keras global state (e.g. auto-generated layer names) before building the model.
tf.keras.backend.clear_session()

class Mymodel(models.Model):
    """Subclassed Keras model: one Dense unit producing a single output.

    The Dense sublayer is created in ``build`` rather than in ``__init__``.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        # Create the sublayer first, then delegate to the base-class build.
        self.dense1 = layers.Dense(1)
        super().build(input_shape)

    def call(self, x):
        # Forward pass: just the single Dense layer.
        return self.dense1(x)
    
# Instantiate the subclassed model and build it for batches of 2-feature inputs.
model = Mymodel()
model.build(input_shape=(None, 2))
model.summary()

"""
Model: "mymodel"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
"""
# Custom training loop: optimizer and loss function.
optimizer, loss_func = optimizers.Adam(), losses.MeanSquaredError()

# Streaming metrics per split: mean of batch losses and mean absolute error.
train_loss, train_metric = (
    tf.keras.metrics.Mean(name="train_loss"),
    tf.keras.metrics.MeanAbsoluteError(name="train_mae"),
)
valid_loss, valid_metric = (
    tf.keras.metrics.Mean(name="valid_loss"),
    tf.keras.metrics.MeanAbsoluteError(name="valid_mae"),
)

@tf.function
def train_step(model, features, labels):
    """Apply one optimizer update on a batch and accumulate training metrics.

    Uses the module-level ``loss_func``, ``optimizer``, ``train_loss`` and
    ``train_metric`` objects.
    """
    with tf.GradientTape() as tape:
        preds = model(features)
        batch_loss = loss_func(labels, preds)
    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    train_loss.update_state(batch_loss)
    train_metric.update_state(labels, preds)
    
@tf.function
def valid_step(model, features, labels):
    """Score one batch and accumulate validation metrics (weights untouched).

    Uses the module-level ``loss_func``, ``valid_loss`` and ``valid_metric``.
    """
    preds = model(features)
    valid_loss.update_state(loss_func(labels, preds))
    valid_metric.update_state(labels, preds)
    
@tf.function
def train_model(model, ds_train, ds_valid, epochs, log_every=100):
    """Train ``model`` for ``epochs`` epochs with a hand-written loop.

    Each epoch does two things:

    * runs ``train_step`` on every batch of ``ds_train``;
    * runs ``valid_step`` on every batch of ``ds_valid``.

    Every ``log_every`` epochs it prints a timestamp bar, that epoch's train
    and validation loss/MAE, and the first layer's ``kernel`` and ``bias``.
    All metrics are reset at the end of each epoch, so the logged values
    cover a single epoch only.

    Args:
        model: Model whose first layer exposes ``kernel`` and ``bias``.
        ds_train: Dataset yielding ``(features, labels)`` training batches.
        ds_valid: Dataset yielding ``(features, labels)`` validation batches.
        epochs: Number of epochs to run.
        log_every: Positive Python int; logging interval in epochs. The
            default of 100 matches the previously hard-coded interval.

    Raises:
        ValueError: If ``log_every`` is not positive (it is used as a modulus).
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")

    # Loop-invariant format string, built once instead of on every epoch.
    logs = 'Epoch={}, Loss:{},MAE:{},Valid Loss:{},Valid MAE:{}'

    for epoch in tf.range(1, epochs + 1):
        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        if epoch % log_every == 0:
            printbar()
            tf.print(tf.strings.format(logs, (epoch, train_loss.result(), train_metric.result(), valid_loss.result(), valid_metric.result())))
            tf.print("w=", model.layers[0].kernel)
            tf.print("b=", model.layers[0].bias)
            tf.print("")

        # Start the next epoch's metrics from zero.
        train_loss.reset_state()
        train_metric.reset_state()
        valid_loss.reset_state()
        valid_metric.reset_state()
        
# Run the custom training loop for 400 epochs.
train_model(model, ds_train, ds_valid, epochs=400)

"""
================================================================================15:11:22
Epoch=100, Loss:9.93945,MAE:2.55268478,Valid Loss:8.8473568,Valid MAE:2.35548472
w= [[1.66429639]
 [-0.957504869]]
b= [1.9203496]

================================================================================15:11:29
Epoch=200, Loss:4.17593479,MAE:1.6221776,Valid Loss:4.06166744,Valid MAE:1.62463975
w= [[2.01739311]
 [-0.978583097]]
b= [2.87123]

================================================================================15:11:35
Epoch=300, Loss:4.17040348,MAE:1.62676454,Valid Loss:4.10403919,Valid MAE:1.63572729
w= [[2.0172267]
 [-0.978128374]]
b= [2.94880891]

================================================================================15:11:41
Epoch=400, Loss:4.17061424,MAE:1.6268065,Valid Loss:4.10414696,Valid MAE:1.6357733
w= [[2.017344]
 [-0.978120208]]
b= [2.94895935]
"""
posted @ 2022-06-19 15:14  lotuslaw  阅读(30)  评论(0)  编辑  收藏  举报