Resuming training from a checkpoint in PyTorch
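import os
import torch
from torch.nn import L1Loss  # assuming L1Loss is the one from torch.nn

# Load the saved checkpoint and restore the network weights.
checkpoint = torch.load('.pth')
net.load_state_dict(checkpoint['net'])

criterion_mse = torch.nn.MSELoss().to(cfg.device)
criterion_L1 = L1Loss()
optimizer = torch.optim.Adam([paras for paras in net.parameters() if paras.requires_grad], lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)

# Restore the optimizer and scheduler state after constructing them.
optimizer.load_state_dict(checkpoint['optimizer'])
# This must be a call. Writing `scheduler.load_state_dict = ...` only overwrites
# the method and leaves the scheduler state unrestored.
scheduler.load_state_dict(checkpoint['lr_schedule'])
start_epoch = checkpoint['epoch']

for idx_epoch in range(start_epoch + 1, 80):
    # Kept at the start of the epoch, as in the original run. Since PyTorch 1.1.0
    # the documented order is optimizer.step() first, then scheduler.step().
    scheduler.step()
    for idx_iter, (...) in enumerate(train_loader):  # batch fields elided in the original
        _ = net(...)  # inputs elided in the original
        loss = criterion_mse(..., ...)  # arguments elided in the original
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if idx_epoch % 1 == 0:  # i.e. save after every epoch
        checkpoint = {
            "net": net.state_dict(),  # network parameters
            'optimizer': optimizer.state_dict(),  # optimizer state
            "epoch": idx_epoch,  # index of the last finished epoch
            'lr_schedule': scheduler.state_dict()  # learning-rate schedule state
        }
        torch.save(checkpoint, os.path.join(save_path, filename))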
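To make the resume procedure above easier to verify in isolation, here is a minimal sketch on a toy problem. Everything in it is hypothetical and not taken from the original project: the `build` and `train_one_epoch` helpers, the `nn.Linear` model, the random data, and the file name `toy_ckpt.pth`. It trains 6 epochs without interruption, then repeats the run with a save and reload after epoch 3, and compares the two results. It assumes a deterministic CPU run with no data shuffling.

```python
import os
import tempfile

import torch
from torch import nn


def build(lr=0.1, step_size=2, gamma=0.5):
    """Create a toy model with its Adam optimizer and StepLR scheduler."""
    torch.manual_seed(0)  # identical initial weights on every call
    net = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    return net, optimizer, scheduler


def train_one_epoch(net, optimizer, scheduler, x, y):
    """One full-batch update, then one scheduler step."""
    loss = nn.functional.mse_loss(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()


torch.manual_seed(0)
x, y = torch.randn(16, 4), torch.randn(16, 1)
total_epochs = 6

# Reference run: train without interruption.
net_a, opt_a, sch_a = build()
for epoch in range(total_epochs):
    train_one_epoch(net_a, opt_a, sch_a, x, y)

# Interrupted run: train 3 epochs, then save everything needed to resume.
net_b, opt_b, sch_b = build()
for epoch in range(3):
    train_one_epoch(net_b, opt_b, sch_b, x, y)
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy_ckpt.pth")
torch.save(
    {
        "net": net_b.state_dict(),
        "optimizer": opt_b.state_dict(),
        "epoch": epoch,  # index of the last finished epoch
        "lr_schedule": sch_b.state_dict(),
    },
    ckpt_path,
)

# Resume: rebuild fresh objects, load all three state dicts, continue training.
net_c, opt_c, sch_c = build()
ckpt = torch.load(ckpt_path)
net_c.load_state_dict(ckpt["net"])
opt_c.load_state_dict(ckpt["optimizer"])
sch_c.load_state_dict(ckpt["lr_schedule"])
for epoch in range(ckpt["epoch"] + 1, total_epochs):
    train_one_epoch(net_c, opt_c, sch_c, x, y)

# Compare the resumed run against the uninterrupted one.
same_weights = all(
    torch.allclose(p_a, p_c) for p_a, p_c in zip(net_a.parameters(), net_c.parameters())
)
print("weights match:", same_weights)
print("learning rates:", sch_a.get_last_lr(), sch_c.get_last_lr())
```

If the `sch_c.load_state_dict(...)` line is skipped, the scheduler starts counting from zero again, so in general the learning rate drops at different epochs than in the uninterrupted run.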
Direct training (no interruption):
a mean psnr: 28.160327919812364
a mean ssim: 0.8067064184409644
b mean psnr: 25.01364162100755
b mean ssim: 0.7600019779915981
c mean psnr: 25.83471135230011
c mean ssim: 0.7774989383731079

Resumed from checkpoint:
a mean psnr: 28.15391601255439
a mean ssim: 0.8062857339309237
b mean psnr: 25.01115760689137
b mean ssim: 0.7596963993692107
c mean psnr: 25.842269038618145
c mean ssim: 0.7772710729947427
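To put a number on the gap between the two runs, the short script below computes the per-set differences from the values reported above. The dictionary layout and variable names are only illustrative; the numbers are copied from the logs.

```python
# (mean PSNR, mean SSIM) per test set, copied from the logs above.
direct = {
    "a": (28.160327919812364, 0.8067064184409644),
    "b": (25.01364162100755, 0.7600019779915981),
    "c": (25.83471135230011, 0.7774989383731079),
}
resumed = {
    "a": (28.15391601255439, 0.8062857339309237),
    "b": (25.01115760689137, 0.7596963993692107),
    "c": (25.842269038618145, 0.7772710729947427),
}

diffs = {}
for name in direct:
    d_psnr = resumed[name][0] - direct[name][0]
    d_ssim = resumed[name][1] - direct[name][1]
    diffs[name] = (d_psnr, d_ssim)
    print(f"{name}: dPSNR = {d_psnr:+.4f} dB, dSSIM = {d_ssim:+.5f}")

# Largest absolute gap across the three sets.
max_abs_psnr = max(abs(d[0]) for d in diffs.values())
max_abs_ssim = max(abs(d[1]) for d in diffs.values())
print(f"max |dPSNR| = {max_abs_psnr:.4f} dB, max |dSSIM| = {max_abs_ssim:.5f}")
```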
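Resuming from a checkpoint gives essentially the same results as direct training: on all three test sets the gap is below 0.01 dB in PSNR and below 0.0005 in SSIM. Small differences remain, though, and I will analyze them in a follow-up.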
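The post does not yet say where the remaining gap comes from. One candidate worth ruling out, offered here as an assumption and not as a finding, is random number generator state: the checkpoint stores the network, optimizer, scheduler, and epoch, but not the RNG state that drives things like data shuffling. A minimal sketch of capturing and restoring that state follows; `capture_rng_state` and `restore_rng_state` are hypothetical helper names.

```python
import random

import torch


def capture_rng_state():
    """Collect the Python and PyTorch RNG states into one dict."""
    state = {
        "python": random.getstate(),
        "torch": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state


def restore_rng_state(state):
    """Put the RNGs back into the captured state."""
    random.setstate(state["python"])
    torch.set_rng_state(state["torch"])
    if "cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["cuda"])


# Usage: draws made after restoring repeat the draws made after capturing.
saved = capture_rng_state()
first = torch.rand(3)
restore_rng_state(saved)
second = torch.rand(3)
print(torch.equal(first, second))
```

Even with RNG state restored, some GPU kernels are nondeterministic, so bit-identical results between a resumed and an uninterrupted run are not guaranteed.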