保存模型 & 记录参数

保存的模型

在你提供的代码中,模型保存的条件如下:

  1. 验证阶段(_valid_epoch 方法)

    • 在每个 epoch 结束后,模型会进行验证,即使用验证数据集(self.valid_loader)计算验证指标(valid_metric)。
    • 通过 self.valid_step 方法计算每个 batch 的验证指标,最终将这些指标的平均值保存在 valid_metric 中。
  2. 与之前最佳指标进行比较(_metric_better 方法)

    • 当前的 valid_metric 会与之前的最佳指标(self.last_valid_metric)进行比较,比较方法由 _metric_better 函数决定。
    • 如果 self.config.metric_min_betterTrue,则较低的 valid_metric 更好(例如用于损失函数);否则,较高的 valid_metric 更好(例如用于准确率)。
  3. 保存模型的条件

    • 如果当前的 valid_metric 比之前的最佳指标更好,则会执行以下操作:
      • 重置耐心值:将 self.patience 重置为 self.config.patience
      • 保存模型:将模型检查点保存到文件中,并将其路径添加到 top-K 检查点列表中。
      • 检查点文件名包含当前的 epoch 和全局步骤,例如 epoch{self.epoch}_step{self.global_step}.ckpt
  4. 维护 Top-K 检查点(_maintain_topk_checkpoint 方法)

    • _maintain_topk_checkpoint 方法会维护一个基于验证指标的 top-K 最佳检查点列表。
    • 如果列表超过 top-K 限制(self.config.save_topk),则会删除指标最差的检查点,并从磁盘中删除该文件。
    • 这样可以确保只保留 top-K 表现最好的模型检查点,以节省存储空间。
  5. 停止训练

    • 如果验证指标未能改善(self.patience 减少),并且 self.patience 达到零,则训练循环会中断,从而提前结束训练过程。
模型保存条件总结:
  • 模型检查点会在当前验证指标(valid_metric)比历史最佳验证指标更好时保存。此外,系统只保留 top-K 表现最好的模型检查点,其余的会被删除以节省存储空间。

    def train_step(self, batch, batch_idx):
    batch['context_ratio'] = self.get_context_ratio()
    return self.share_step(batch, batch_idx, val=False)

    def valid_step(self, batch, batch_idx):
    batch['context_ratio'] = 0
    return self.share_step(batch, batch_idx, val=True)

    def share_step(self, batch, batch_idx, val=False):
    loss, seq_detail, structure_detail, dock_detail, pdev_detail = self.model(**batch)
    snll, aar = seq_detail
    struct_loss, xloss, bond_loss, sc_bond_loss = structure_detail
    dock_loss, interface_loss, ed_loss, r_ed_losses = dock_detail
    pdev_loss, prmsd_loss = pdev_detail

      log_type = 'Validation' if val else 'Train'
    
      self.log(f'Overall/Loss/{log_type}', loss, batch_idx, val)
    
      self.log(f'Seq/SNLL/{log_type}', snll, batch_idx, val)
      self.log(f'Seq/AAR/{log_type}', aar, batch_idx, val)
    
      self.log(f'Struct/StructLoss/{log_type}', struct_loss, batch_idx, val)
      self.log(f'Struct/XLoss/{log_type}', xloss, batch_idx, val)
      self.log(f'Struct/BondLoss/{log_type}', bond_loss, batch_idx, val)
      self.log(f'Struct/SidechainBondLoss/{log_type}', sc_bond_loss, batch_idx, val)
    
      self.log(f'Dock/DockLoss/{log_type}', dock_loss, batch_idx, val)
      self.log(f'Dock/SPLoss/{log_type}', interface_loss, batch_idx, val)
      self.log(f'Dock/EDLoss/{log_type}', ed_loss, batch_idx, val)
      for i, l in enumerate(r_ed_losses):
          self.log(f'Dock/edloss{i}/{log_type}', l, batch_idx, val)
    
      if pdev_loss is not None:
          self.log(f'PDev/PDevLoss/{log_type}', pdev_loss, batch_idx, val)
          self.log(f'PDev/PRMSDLoss/{log_type}', prmsd_loss, batch_idx, val)
    
      if not val:
          lr = self.config.lr if self.scheduler is None else self.scheduler.get_last_lr()
          lr = lr[0]
          self.log('lr', lr, batch_idx, val)
          self.log('context_ratio', batch['context_ratio'], batch_idx, val)
      return loss
    
posted @   GraphL  阅读(32)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
点击右上角即可分享
微信分享提示