保存模型 & 记录参数
保存的模型
在你提供的代码中,模型保存的条件如下:
-
验证阶段(
_valid_epoch
方法):- 在每个 epoch 结束后,模型会进行验证,即使用验证数据集(
self.valid_loader
)计算验证指标(valid_metric
)。 - 通过
self.valid_step
方法计算每个 batch 的验证指标,最终将这些指标的平均值保存在valid_metric
中。
- 在每个 epoch 结束后,模型会进行验证,即使用验证数据集(
-
与之前最佳指标进行比较(
_metric_better
方法):- 当前的
valid_metric
会与之前的最佳指标(self.last_valid_metric
)进行比较,比较方法由_metric_better
函数决定。 - 如果
self.config.metric_min_better
为True
,则较低的valid_metric
更好(例如用于损失函数);否则,较高的valid_metric
更好(例如用于准确率)。
- 当前的
-
保存模型的条件:
- 如果当前的
valid_metric
比之前的最佳指标更好,则会执行以下操作:- 重置耐心值:将
self.patience
重置为self.config.patience
。 - 保存模型:将模型检查点保存到文件中,并将其路径添加到 top-K 检查点列表中。
- 检查点文件名包含当前的 epoch 和全局步骤,例如
epoch{self.epoch}_step{self.global_step}.ckpt
。
- 重置耐心值:将
- 如果当前的
-
维护 Top-K 检查点(
_maintain_topk_checkpoint
方法):_maintain_topk_checkpoint
方法会维护一个基于验证指标的 top-K 最佳检查点列表。- 如果列表超过 top-K 限制(
self.config.save_topk
),则会删除指标最差的检查点,并从磁盘中删除该文件。 - 这样可以确保只保留 top-K 表现最好的模型检查点,以节省存储空间。
-
停止训练:
- 如果验证指标未能改善(
self.patience
减少),并且self.patience
达到零,则训练循环会中断,从而提前结束训练过程。
- 如果验证指标未能改善(
模型保存条件总结:
-
模型检查点会在当前验证指标(
valid_metric
)比历史最佳验证指标更好时保存。此外,系统只保留 top-K 表现最好的模型检查点,其余的会被删除以节省存储空间。def train_step(self, batch, batch_idx):
batch['context_ratio'] = self.get_context_ratio()
return self.share_step(batch, batch_idx, val=False)def valid_step(self, batch, batch_idx):
batch['context_ratio'] = 0
return self.share_step(batch, batch_idx, val=True)def share_step(self, batch, batch_idx, val=False):
loss, seq_detail, structure_detail, dock_detail, pdev_detail = self.model(**batch)
snll, aar = seq_detail
struct_loss, xloss, bond_loss, sc_bond_loss = structure_detail
dock_loss, interface_loss, ed_loss, r_ed_losses = dock_detail
pdev_loss, prmsd_loss = pdev_detaillog_type = 'Validation' if val else 'Train' self.log(f'Overall/Loss/{log_type}', loss, batch_idx, val) self.log(f'Seq/SNLL/{log_type}', snll, batch_idx, val) self.log(f'Seq/AAR/{log_type}', aar, batch_idx, val) self.log(f'Struct/StructLoss/{log_type}', struct_loss, batch_idx, val) self.log(f'Struct/XLoss/{log_type}', xloss, batch_idx, val) self.log(f'Struct/BondLoss/{log_type}', bond_loss, batch_idx, val) self.log(f'Struct/SidechainBondLoss/{log_type}', sc_bond_loss, batch_idx, val) self.log(f'Dock/DockLoss/{log_type}', dock_loss, batch_idx, val) self.log(f'Dock/SPLoss/{log_type}', interface_loss, batch_idx, val) self.log(f'Dock/EDLoss/{log_type}', ed_loss, batch_idx, val) for i, l in enumerate(r_ed_losses): self.log(f'Dock/edloss{i}/{log_type}', l, batch_idx, val) if pdev_loss is not None: self.log(f'PDev/PDevLoss/{log_type}', pdev_loss, batch_idx, val) self.log(f'PDev/PRMSDLoss/{log_type}', prmsd_loss, batch_idx, val) if not val: lr = self.config.lr if self.scheduler is None else self.scheduler.get_last_lr() lr = lr[0] self.log('lr', lr, batch_idx, val) self.log('context_ratio', batch['context_ratio'], batch_idx, val) return loss
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY