训练集、验证集和测试集的概念及划分原则
深度学习中,常将可得的数据集划分为训练集(training set),验证集(development set/validation set)和测试集(test set).下文主要回答以下几个问题:一是为什么要将数据集划分为如上三个集合,三个集合之间有什么区别;二是我们划分的原则是什么.
1. 训练集、验证集和测试集的概念
-
训练集:顾名思义指的是用于训练的样本集合,主要用来训练神经网络中的参数.
-
验证集:从字面意思理解即为用于验证模型性能的样本集合.不同神经网络在训练集上训练结束后,通过验证集来比较判断各个模型的性能.这里的不同模型主要是指对应不同超参数的神经网络,也可以指完全不同结构的神经网络.
-
测试集:对于训练完成的神经网络,测试集用于客观的评价神经网络的性能.
那么,训练集、验证集和测试集之间又有什么区别呢?一般而言,训练集与后两者之间较易分辨,验证集和测试集之间的概念较易混淆.个人是从下面的角度来理解的:
-
神经网络在网络结构确定的情况下,有两部分影响模型最终的性能,一是普通参数(比如权重w和偏置b),另一个是超参数(例如学习率,网络层数).普通参数我们在训练集上进行训练,超参数我们一般人工指定(比较不同超参数的模型在验证集上的性能).那为什么我们不像普通参数一样在训练集上训练超参数呢?(花书给出了解答)一是超参数一般难以优化(无法像普通参数一样通过梯度下降的方式进行优化).二是超参数很多时候不适合在训练集上进行训练,例如,如果在训练集上训练能控制模型容量的超参数,这些超参数总会被训练成使得模型容量最大的参数(因为模型容量越大,训练误差越小),所以训练集上训练超参数的结果就是模型绝对过拟合.
-
正因为超参数无法在训练集上进行训练,因此我们单独设立了一个验证集,用于选择(人工训练)最优的超参数.因为验证集是用于选择超参数的,因此验证集和训练集是独立不重叠的.
-
测试集是用于在完成神经网络训练过程后,为了客观评价模型在其未见过(未曾影响普通参数和超参数选择)的数据上的性能,因此测试与验证集和训练集之间也是独立不重叠的,而且测试集不能提出对参数或者超参数的修改意见,只能作为评价网络性能的一个指标.
至此,我们可以将神经网络完整的训练过程归结为一下两个步骤:
-
训练普通参数.在训练集(给定超参数)上利用学习算法,训练普通参数,使得模型在训练集上的误差降低到可接受的程度(一般接近人类的水平).
-
'训练'超参数.在验证集上验证网络的generalization error(泛化能力),并根据模型性能对超参数进行调整.
重复1和2两个步骤,直至网络在验证集上取得较低的generalization error.此时完整的训练过程结束.在完成参数和超参数的训练后,在测试集上测试网络的性能.
2. 训练集、验证集和测试集的划分原则
本部分内容主要总结自Andrew Ng课程,课程中给出的原则是:
-
对于小规模样本集,常用的非配比例是trianing set/dev set/test set 6:2:2.例如共有10000个样本,则训练集分为6000个样本,验证集为2000样本,测试集为2000样本.
-
对于大规模样本集,则dev/test set的比例会减小很多,因为验证(比较)模型性能和测试模型性能一定的样本规模就足够了.例如共有1000000个样本,则训练集分为9980000个样本,验证集分为10000个样本,测试集分为10000个样本.
当我们不能获得足够的感兴趣的训练样本时,利用其他一些类似数据来训练网络时,该如何划分training, dev and test set?
例如我们在做一个识别猫的程序,我们的目标是识别用户拍照上传的猫的图片,但是我们能获得的APP上传的数据十分有限(例如10000张),所以准备通过利用网络爬虫下载的猫的图片(200000张)来协助训练网络.而因为网络爬取的图片与用户上传的图片有较大区别,这时候应该如何划分training/dev/test set?
- 一种方案是将app与web图片进行混合,然后按照大数据划分原则进行划分,即205000张training set,2500张dev set,2500张测试集.
- 另一种方案是,将app图片中2500张分给dyev set,2500张分给test set,5000张app图片和200000张web图片混合作为训练集.
Andrew Ng指出第二种方案更好,因为第二种方案dev set的数据全部来自app,与我们真正关心的数据具有相同的分布.而第一种方案,dev set中大概只有120张图片来自app,而剩下的大部分来自web,必然导致模型的评价准则偏移靶心.
当然,第二种方案会导致dev/test set与training set不同分布的问题,这会给误差分析带来麻烦.Andrew Ng给出的解决方案是在training set中划分出一部分作为train-dev set,该部分不用于训练,作为评价模型generalization error,而train-dev set与dev set之间的误差作为data mismatch error,表示数据分布不同引起的误差.