tensorflow2.0——过拟合优化regularization(简化参数结构,添加参数代价变量)

参数过多会导致模型过于复杂而出现过拟合现象,通过在loss函数添加关于参数个数的代价变量,限制参数个数,来达到减小过拟合的目的

 

以下是loss公式:

 

 

 代码多了一个kernel_regularizer参数

  

 

 

 

复制代码
import tensorflow as tf

def preporocess(x,y):
    x = tf.cast(x,dtype=tf.float32) / 255
    x = tf.reshape(x,(-1,28 *28))                   #   铺平
    x = tf.squeeze(x,axis=0)
    # print('里面x.shape:',x.shape)
    y = tf.cast(y,dtype=tf.int32)
    y = tf.one_hot(y,depth=10)
    return x,y

def main():
    #   加载手写数字数据
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    #   处理数据
    #   训练数据
    db = tf.data.Dataset.from_tensor_slices((train_x, train_y))  # 将x,y分成一一对应的元组
    db = db.map(preporocess)  # 执行预处理函数
    db = db.shuffle(60000).batch(2000)  # 打乱加分组
    #   测试数据
    db_test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    db_test = db_test.map(preporocess)
    db_test = db_test.shuffle(10000).batch(10000)
    #   设置超参
    iter_num = 2000  # 迭代次数
    lr = 0.01  # 学习率
    #   定义模型器和优化器
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),       #   kernel_regularizer是loss上加了关于参数的损失变量
        tf.keras.layers.Dense(128, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(64, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(32, activation='relu',kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        tf.keras.layers.Dense(10)
    ])
    #   优化器
    # optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)              #   定义优化器
    model.compile(optimizer= optimizer,loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])       #   定义模型配置
    model.fit(db,epochs=30,validation_data=db,validation_freq=2)          #  运行模型,参数validation_data是指在哪个测试集上进行测试
    model.evaluate(db_test)                                                     #   最后打印测试数据相关准确率数据

if __name__ == '__main__':
    main()
复制代码

 

posted @   山…隹  阅读(724)  评论(1编辑  收藏  举报
编辑推荐:
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
点击右上角即可分享
微信分享提示