3、ModelCheckPoint
1、导包
1 from tensorflow.keras.callbacks import ModelCheckpoint
2、介绍
在训练机器学习模型时,经常需要缓存模型。
ModelCheckpoint
是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。
它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。
在每个epoch结束作为回调函数,保存模型。
3、参数介绍
3.1、monitor='val_loss', 我们想要监视的指标 ,val_acc或val_loss。
3.2、dirpath='my/path/', 模型缓存目录
3.3、verbose: 详细信息模式,0 或者1。 0为不打印输出信息,1为打印
3.4、save_best_only: True,将只保存在验证集上性能最好的模型mode: {auto, min, max} 的其中之一。是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。
对于val_acc,模式就会是max;而对于val_loss,模式就需要是min。在auto模式中,方式会自动从被监测的数据的名字中判断出来。
3.5、save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
3.6、period: 每个检查点之间的间隔(训练轮数)。