[pytorch] A survey of cosine annealing + warmup implementations

tl;dr: PyTorch's built-in torch.optim.lr_scheduler.OneCycleLR works well: it covers both warmup and a cosine learning-rate schedule, with no extra package to install.
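
A minimal sketch of that recommended setup (the numbers below are placeholders): pct_start sets the fraction of steps spent warming up, and anneal_strategy='cos' (the default) gives the cosine decay afterwards.

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# warm up for the first 10% of the 1000 steps, then cosine-anneal the lr back down
scheduler = OneCycleLR(optimizer, max_lr=1e-3, total_steps=1000,
                       pct_start=0.1, anneal_strategy='cos')

for step in range(1000):
    optimizer.step()   # normally after loss.backward()
    scheduler.step()   # OneCycleLR is stepped once per batch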

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
import matplotlib.pyplot as plt
from timm import scheduler as timm_scheduler 
from timm.scheduler.scheduler import Scheduler as timm_BaseScheduler
from torch.optim import Optimizer
from torch.optim import lr_scheduler
from transformers import get_cosine_schedule_with_warmup


def warmup_stable_decay(optimizer: Optimizer, max_steps, warmup_ratio=0.05, decay_ratio=0.2):
  # WSD (Warmup-Stable-Decay) schedule, from "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations"
  n_warmup = max_steps * warmup_ratio
  n_decay = max_steps * decay_ratio  # when max_steps exceeds ~20k, this ratio can be below 0.2
  def inner(step):
    if step < n_warmup:
      return step / n_warmup  # linear warmup
    elif step <= max_steps - n_decay:
      return 1  # stable phase at the full learning rate
    else:
      t = (step - (max_steps - n_decay)) / n_decay  # progress through the decay phase (0 -> 1)
      t = 1 - t**0.5  # 1-sqrt decay (1 -> 0)
      return t
  scheduler = lr_scheduler.LambdaLR(optimizer, inner)
  return scheduler
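
# Sanity check of the multiplier shape (hypothetical numbers: max_steps=1000, so
# n_warmup=50 and n_decay=200 with the default ratios):
#   inner(0)   -> 0.0                    start of the linear warmup
#   inner(25)  -> 0.5                    halfway through warmup
#   inner(500) -> 1                      stable phase
#   inner(900) -> 1 - sqrt(0.5) ≈ 0.29   halfway through the decay phase
#   inner(999) -> ≈ 0.0025               end of the 1-sqrt decay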


def linear_warmup_drop(optimizer: Optimizer, steps_per_epoch: int, warmup_epochs: float, drop_epoch_list: list[int]|tuple[int, ...]|None, drop_rate: float):
  # Linear warmup followed by a piecewise-constant schedule: the multiplier is cut by
  # drop_rate each time training passes one of the epochs in drop_epoch_list.
  warmup_steps = max(1, int(steps_per_epoch * warmup_epochs))
  rate = 1.0

  def inner(step):
    nonlocal rate
    nonlocal drop_epoch_list
    if step < warmup_steps:  # linear warmup
      return step / warmup_steps
    elif not isinstance(drop_epoch_list, (list, tuple)) or len(drop_epoch_list) == 0:
      return rate  # no (remaining) drop points: keep the current rate
    i = 0
    for e in drop_epoch_list:  # count how many drop boundaries this step has passed
      if step > e*steps_per_epoch:
        i += 1
      else:
        break
    rate = rate * (drop_rate**i)
    drop_epoch_list = drop_epoch_list[i:]  # consume the boundaries that were just applied
    return rate

  scheduler = lr_scheduler.LambdaLR(optimizer, inner)
  return scheduler
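
# A stateless variant of the same idea (a sketch, not one of the surveyed schedulers): because
# inner() above mutates `rate` and `drop_epoch_list`, it only behaves correctly when called with
# monotonically increasing steps; computing the multiplier from `step` alone lifts that assumption,
# so the lambda can be evaluated at any step in any order (e.g. just for plotting).
def linear_warmup_drop_stateless(optimizer: Optimizer, steps_per_epoch: int, warmup_epochs: float, drop_epoch_list, drop_rate: float):
  warmup_steps = max(1, int(steps_per_epoch * warmup_epochs))
  drops = tuple(drop_epoch_list or ())

  def inner(step):
    if step < warmup_steps:
      return step / warmup_steps  # linear warmup
    n_passed = sum(1 for e in drops if step > e * steps_per_epoch)  # drop boundaries already passed
    return drop_rate ** n_passed  # equivalent to multiplying by drop_rate at every boundary

  return lr_scheduler.LambdaLR(optimizer, inner)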


model = torch.nn.Linear(10, 1)
lr = 1  # alternatively: 32*4.5e-6
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
max_epoch = 10
steps_per_epoch = 100
total_steps = max_epoch * steps_per_epoch
mode = 'warmup_drop'
current_epoch = 0


match mode:
  case 'cosineAnn':
    steps_per_epoch = 1
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=4.5e-6)
  case 'cosineAnnWarm':
    '''
    Using T_0=5, T_mult=1 as an example:
    T_0: the epoch at which the learning rate first returns to its initial value.
    T_mult: controls how quickly the restart interval grows.
        - If T_mult=1, the lr returns to its maximum (the initial lr) at T_0, 2*T_0, 3*T_0, ..., i*T_0, ...
            - i.e. at epochs 5, 10, 15, 20, 25, ...
        - If T_mult>1, the lr returns to its maximum at T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ..., (1+T_mult+...+T_mult**i)*T_0, ...
            - e.g. with T_mult=2: at epochs 5, 15, 35, 75, 155, ...
    '''
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=8, T_mult=1)
  case 'cosineTimm':
    steps_per_epoch = 1
    scheduler = timm_scheduler.CosineLRScheduler(optimizer=optimizer, t_initial=max_epoch, lr_min=4.5e-6, warmup_t=1, warmup_lr_init=4.5e-6)
  case 'cosineTorchLambda':    
    warmup_epoch = 2
    warmup_factor = 1e-3
    steps_per_epoch = 1
    def f(current_epoch):
      """
      :current_epoch epoch或者iteration
      :return 根据step数返回一个学习率倍率因子
      注意在训练开始之前,pytorch似乎会提前调用一次lr_scheduler.step()方法
      """
      if current_epoch <= warmup_epoch:
          alpha = float(current_epoch) / (warmup_epoch)
          # warmup过程中lr倍率因子大小从warmup_factor -> 1
          return warmup_factor * (1 - alpha) + alpha  # 对于alpha的一个线性变换,alpha是关于x的一个反比例函数变化
      else:
          # warmup后lr的倍率因子从1 -> 0
          # 参考deeplab_v2: Learning rate policy
          return (1 - (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch)) ** 0.9
      # (1-a/b)^0.9 b是当前这个epoch结束训练总共了多少次了(除去warmup),这个关系是指一个epcoch中
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
  case 'step':
    steps_per_epoch = 1
    scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.5)
  case 'oneCycle':
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, epochs=max_epoch, steps_per_epoch=steps_per_epoch, pct_start=0.1, final_div_factor=10)
  case 'cosineTransformers':
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=steps_per_epoch, num_training_steps=max_epoch*steps_per_epoch)
  case 'WSD':
    scheduler = warmup_stable_decay(optimizer, total_steps)
  case 'warmup_drop':
    scheduler = linear_warmup_drop(optimizer, steps_per_epoch, 0.5, [2, 8], 0.1)
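
# Note on the steps_per_epoch = 1 lines above: cosineAnn, cosineTimm, cosineTorchLambda and
# step are epoch-based schedulers, so the loop below ends up calling scheduler.step() once per
# "epoch"; the other modes (cosineAnnWarm, oneCycle, cosineTransformers, WSD, warmup_drop)
# are stepped once per batch/iteration.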

plt.figure()
lr_history = []
for epoch in range(max_epoch):
  for step in range(steps_per_epoch):
    optimizer.step()
    current_lr = optimizer.param_groups[0]['lr']
    if isinstance(scheduler, timm_BaseScheduler):
      scheduler.step(epoch)
    else:
      scheduler.step()
    lr_history.append(current_lr)

print(lr_history)
plt.plot(range(len(lr_history)), lr_history)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title(f'LR schedule: {mode}')
plt.show()
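
For completeness, a warmup + cosine schedule can also be assembled from built-in PyTorch schedulers alone by chaining LinearLR and CosineAnnealingLR with SequentialLR. A minimal sketch of that option (the step counts and learning rates below are placeholders):

import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

total_steps = 1000
warmup_steps = 100

warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps)  # linear warmup
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-6)      # cosine decay
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(total_steps):
    optimizer.step()   # normally after loss.backward()
    scheduler.step()   # stepped once per batch, like OneCycleLR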

Author: 心有所向,日复一日,必有精进

Link: https://www.cnblogs.com/Stareven233/p/17870826.html

License: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 2.5 China Mainland license.
