3-2中阶API示范eat_tensorflow2_in_30_days

3-2中阶API示范#

TensorFlow的中阶API主要包括各种模型层,损失函数,优化器,数据管道,特征列等等.

import tensorflow as tf
from tensorflow.keras import layers,losses,metrics,optimizers


#打印时间分割线
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minite = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)
    
    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))
    
    timestring = tf.strings.join([timeformat(hour),timeformat(minite),
                timeformat(second)],separator = ":")
    tf.print("=========="*8, end = "")
    tf.print(timestring)
# 样本数量
n = 800

# 生成测试用数据集
X = tf.random.uniform([n, 2], minval=10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)
Y = X@w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)  # @表示矩阵乘法,增加正态扰动

# 构建输入数据管道
ds = tf.data.Dataset.from_tensor_slices((X, Y)) \
    .shuffle(buffer_size=1000) \
    .batch(100) \
    .prefetch(tf.data.experimental.AUTOTUNE)  # 表示tf.data模块运行时,框架会根据可用的CPU自动设置最大的可用线程数,以使用多线程进行数据通道处理,将机器的算力拉满。注意返回的变量其实是个常量,表示可用的线程数目。

# 定义优化器
optimizer  = optimizers.SGD(learning_rate=0.001)
linear = layers.Dense(units=1)
linear.build(input_shape=(2,))  # 初始化模型

@tf.function
def train(epoches):
    for epoch in tf.range(1, epoches+1):
        L = tf.constant(0.0)  # 使用L记录loss值
        for X_batch, Y_batch in ds:
            with tf.GradientTape() as tape:
                Y_hat = linear(X_batch)
                loss = losses.mean_squared_error(tf.reshape(Y_hat, [-1]), tf.reshape(Y_batch, [-1]))
            grads = tape.gradient(loss, linear.variables)
            optimizer.apply_gradients(zip(grads, linear.variables))
            L = loss
            
        if (epoch % 100 == 0):
            printbar()
            tf.print("epoch=", epoch, "loss=", L)
            tf.print("w=", linear.kernel)
            tf.print("b=", linear.bias)
            tf.print("")
train(500)

"""
================================================================================14:08:41
epoch= 100 loss= 4.19749069
w= [[1.98472738]
 [-1.00138092]]
b= [2.3219955]

================================================================================14:08:43
epoch= 200 loss= 4.11556864
w= [[1.98508048]
 [-0.999055743]]
b= [2.78728962]

================================================================================14:08:46
epoch= 300 loss= 3.96286321
w= [[1.98485506]
 [-0.999063373]]
b= [2.88118291]

================================================================================14:08:48
epoch= 400 loss= 3.92058372
w= [[1.9838568]
 [-0.998188376]]
b= [2.90016508]

================================================================================14:08:50
epoch= 500 loss= 3.35933471
w= [[1.98283279]
 [-0.997389]]
b= [2.90403414]
"""

作者:lotuslaw

出处:https://www.cnblogs.com/lotuslaw/p/16390439.html

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   lotuslaw  阅读(20)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
more_horiz
keyboard_arrow_up light_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示