使用贝叶斯优化进行深度神经网络超参数优化

在本文中,我们将深入研究超参数优化。

为了方便起见本文将使用 Tensorflow 中包含的 Fashion MNIST[1] 数据集。该数据集在训练集中包含 60,000 张灰度图像,在测试集中包含 10,000 张图像。每张图片代表属于 10 个类别之一的单品(“T 恤/上衣”、“裤子”、“套头衫”等)。因此这是一个多类分类问题。

这里简单介绍准备数据集的步骤,因为本文的主要内容是超参数的优化,所以这部分只是简单介绍流程,一般情况下,流程如下:

  • 加载数据。
  • 分为训练集、验证集和测试集。
  • 将像素值从 0–255 标准化到 0–1 范围。
  • One-hot 编码目标变量。
  1. #load data
  2. (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
  3. # split into train, validation and test sets
  4. train_x, val_x, train_y, val_y = train_test_split(train_images, train_labels, stratify=train_labels, random_state=48, test_size=0.05)
  5. (test_x, test_y)=(test_images, test_labels)
  6. # normalize pixels to range 0-1
  7. train_x = train_x / 255.0
  8. val_x = val_x / 255.0
  9. test_x = test_x / 255.0
  10. #one-hot encode target variable
  11. train_y = to_categorical(train_y)
  12. val_y = to_categorical(val_y)
  13. test_y = to_categorical(test_y)

我们所有训练、验证和测试集的形状是:

  1. print(train_x.shape) #(57000, 28, 28)
  2. print(train_y.shape) #(57000, 10)
  3. print(val_x.shape) #(3000, 28, 28)
  4. print(val_y.shape) #(3000, 10)
  5. print(test_x.shape) #(10000, 28, 28)
  6. print(test_y.shape) #(10000, 10)

现在,我们将使用 Keras Tuner 库 [2]:它将帮助我们轻松调整神经网络的超参数:

  1. pip install keras-tuner

Keras Tuner 需要 Python 3.6+ 和 TensorFlow 2.0+

超参数调整是机器学习项目的基础部分。有两种类型的超参数:

  • 结构超参数:定义模型的整体架构(例如隐藏单元的数量、层数)
  • 优化器超参数:影响训练速度和质量的参数(例如学习率和优化器类型、批量大小、轮次数等)

为什么需要超参数调优库?我们不能尝试所有可能的组合,看看验证集上什么是最好的吗?

这肯定是不行的因为深度神经网络需要大量时间来训练,甚至几天。如果在云服务器上训练大型模型,那么每个实验实验都需要花很多的钱。

因此,需要一种限制超参数搜索空间的剪枝策略。

keras-tuner提供了贝叶斯优化器。它搜索每个可能的组合,而是随机选择前几个。然后根据这些超参数的性能,选择下一个可能的最佳值。因此每个超参数的选择都取决于之前的尝试。根据历史记录选择下一组超参数并评估性能,直到找到最佳组合或到达最大试验次数。我们可以使用参数“max_trials”来配置它。

除了贝叶斯优化器之外,keras-tuner还提供了另外两个常见的方法:RandomSearch 和 Hyperband。我们将在本文末尾讨论它们。

接下来就是对我们的网络应用超参数调整。我们尝试两种网络架构,标准多层感知器(MLP)和卷积神经网络(CNN)。

首先让我们看看基线 MLP 模型是什么:

  1. model_mlp = Sequential()
  2. model_mlp.add(Flatten(input_shape=(28, 28)))
  3. model_mlp.add(Dense(350, activation='relu'))
  4. model_mlp.add(Dense(10, activation='softmax'))
  5. print(model_mlp.summary())
  6. model_mlp.compile(optimizer="adam",loss='categorical_crossentropy')

调优过程需要两种主要方法:

hp.Int():设置超参数的范围,其值为整数 - 例如,密集层中隐藏单元的数量:

  1. model.add(Dense(units = hp.Int('dense-bot', min_value=50, max_value=350, step=50))

hp.Choice():为超参数提供一组值——例如,Adam 或 SGD 作为最佳优化器?

  1. hp_optimizer=hp.Choice('Optimizer', values=['Adam', 'SGD'])

在我们的 MLP 示例中,我们测试了以下超参数:

  • 隐藏层数:1-3
  • 第一密集层大小:50–350
  • 第二和第三密集层大小:50–350
  • Dropout:0、0.1、0.2
  • 优化器:SGD(nesterov=True,momentum=0.9) 或 Adam
  • 学习率:0.1、0.01、0.001

代码如下:

  1. model = Sequential()
  2. model.add(Dense(units = hp.Int('dense-bot', min_value=50, max_value=350, step=50), input_shape=(784,), activation='relu'))
  3. for i in range(hp.Int('num_dense_layers', 1, 2)):
  4. model.add(Dense(units=hp.Int('dense_' + str(i), min_value=50, max_value=100, step=25), activation='relu'))
  5. model.add(Dropout(hp.Choice('dropout_'+ str(i), values=[0.0, 0.1, 0.2])))
  6. model.add(Dense(10,activation="softmax"))
  7. hp_optimizer=hp.Choice('Optimizer', values=['Adam', 'SGD'])
  8. if hp_optimizer == 'Adam':
  9. hp_learning_rate = hp.Choice('learning_rate', values=[1e-1, 1e-2, 1e-3])
  10. elif hp_optimizer == 'SGD':
  11. hp_learning_rate = hp.Choice('learning_rate', values=[1e-1, 1e-2, 1e-3])
  12. nesterov=True
  13. momentum=0.9

这里需要注意第 5 行的 for 循环:让模型决定网络的深度!

最后,就是运行了。请注意我们之前提到的 max_trials 参数。

  1. model.compile(optimizer = hp_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
  2. tuner_mlp = kt.tuners.BayesianOptimization(
  3. model,
  4. seed=random_seed,
  5. objective='val_loss',
  6. max_trials=30,
  7. directory='.',
  8. project_name='tuning-mlp')
  9. tuner_mlp.search(train_x, train_y, epochs=50, batch_size=32, validation_data=(dev_x, dev_y), callbacks=callback)

我们得到结果

这个过程用尽了迭代次数,大约需要 1 小时才能完成。我们还可以使用以下命令打印模型的最佳超参数:

  1. best_mlp_hyperparameters = tuner_mlp.get_best_hyperparameters(1)[0]
  2. print("Best Hyper-parameters")
  3. best_mlp_hyperparameters.values

现在我们可以使用最优超参数重新训练我们的模型:

完整文章

https://avoid.overfit.cn/post/c3f904fab4f84914b8a1935f8670582f

posted @   deephub  阅读(465)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示