3-2中阶API示范eat_tensorflow2_in_30_days
3-2中阶API示范
TensorFlow的中阶API主要包括各种模型层,损失函数,优化器,数据管道,特征列等等.
import tensorflow as tf
from tensorflow.keras import layers,losses,metrics,optimizers
#打印时间分割线
@tf.function
def printbar():
ts = tf.timestamp()
today_ts = ts%(24*60*60)
hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
minite = tf.cast((today_ts%3600)//60,tf.int32)
second = tf.cast(tf.floor(today_ts%60),tf.int32)
def timeformat(m):
if tf.strings.length(tf.strings.format("{}",m))==1:
return(tf.strings.format("0{}",m))
else:
return(tf.strings.format("{}",m))
timestring = tf.strings.join([timeformat(hour),timeformat(minite),
timeformat(second)],separator = ":")
tf.print("=========="*8, end = "")
tf.print(timestring)
# 样本数量
n = 800
# 生成测试用数据集
X = tf.random.uniform([n, 2], minval=10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)
Y = X@w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0) # @表示矩阵乘法,增加正态扰动
# 构建输入数据管道
ds = tf.data.Dataset.from_tensor_slices((X, Y)) \
.shuffle(buffer_size=1000) \
.batch(100) \
.prefetch(tf.data.experimental.AUTOTUNE) # 表示tf.data模块运行时,框架会根据可用的CPU自动设置最大的可用线程数,以使用多线程进行数据通道处理,将机器的算力拉满。注意返回的变量其实是个常量,表示可用的线程数目。
# 定义优化器
optimizer = optimizers.SGD(learning_rate=0.001)
linear = layers.Dense(units=1)
linear.build(input_shape=(2,)) # 初始化模型
@tf.function
def train(epoches):
for epoch in tf.range(1, epoches+1):
L = tf.constant(0.0) # 使用L记录loss值
for X_batch, Y_batch in ds:
with tf.GradientTape() as tape:
Y_hat = linear(X_batch)
loss = losses.mean_squared_error(tf.reshape(Y_hat, [-1]), tf.reshape(Y_batch, [-1]))
grads = tape.gradient(loss, linear.variables)
optimizer.apply_gradients(zip(grads, linear.variables))
L = loss
if (epoch % 100 == 0):
printbar()
tf.print("epoch=", epoch, "loss=", L)
tf.print("w=", linear.kernel)
tf.print("b=", linear.bias)
tf.print("")
train(500)
"""
================================================================================14:08:41
epoch= 100 loss= 4.19749069
w= [[1.98472738]
[-1.00138092]]
b= [2.3219955]
================================================================================14:08:43
epoch= 200 loss= 4.11556864
w= [[1.98508048]
[-0.999055743]]
b= [2.78728962]
================================================================================14:08:46
epoch= 300 loss= 3.96286321
w= [[1.98485506]
[-0.999063373]]
b= [2.88118291]
================================================================================14:08:48
epoch= 400 loss= 3.92058372
w= [[1.9838568]
[-0.998188376]]
b= [2.90016508]
================================================================================14:08:50
epoch= 500 loss= 3.35933471
w= [[1.98283279]
[-0.997389]]
b= [2.90403414]
"""