[pytorch] A survey of cosine annealing + warmup implementations
tl;dr: PyTorch's torch.optim.lr_scheduler.OneCycleLR works well here: it covers both warmup and a cosine learning-rate schedule, and needs no extra packages.
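A minimal sketch of that recommendation (the model, learning rate, and step counts below are illustrative placeholders, not tuned values):

import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# warm up over the first 10% of steps (pct_start), then cosine-anneal back down
scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=10, steps_per_epoch=100, pct_start=0.1)

The script below compares this against several other warmup-capable schedulers (timm, transformers, a WSD schedule, and a few hand-rolled LambdaLR variants).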
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
import matplotlib.pyplot as plt
from timm import scheduler as timm_scheduler
from timm.scheduler.scheduler import Scheduler as timm_BaseScheduler
from torch.optim import Optimizer
from torch.optim import lr_scheduler
from transformers import get_cosine_schedule_with_warmup
def warmup_stable_decay(optimizer: Optimizer, max_steps, warmup_ratio=0.05, decay_ratio=0.2):
    # WSD (Warmup-Stable-Decay) schedule, from "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations"
    n_warmup = max_steps * warmup_ratio
    n_decay = max_steps * decay_ratio  # for long runs (n > 20k steps) the decay ratio can be smaller than 0.2
    def inner(step):
        if step < n_warmup:
            return step / n_warmup  # linear warmup
        elif step <= max_steps - n_decay:
            return 1  # stable phase
        else:
            t = (step - (max_steps - n_decay)) / n_decay  # (0 -> 1)
            t = 1 - t**0.5  # (1 -> 0)
            return t
    scheduler = lr_scheduler.LambdaLR(optimizer, inner)
    return scheduler
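# Illustrative shape (assuming max_steps=1000 and the default ratios above): the multiplier
# rises linearly from 0 to 1 over steps 0-50, holds at 1 until step 800, then decays as
# 1 - sqrt(t), e.g. ~0.29 at step 900 and 0 at step 1000.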
def linear_warmup_drop(optimizer: Optimizer, steps_per_epoch: int, warmup_epochs: float, drop_epoch_list: None | list[int] | tuple[int, ...], drop_rate: float):
    warmup_steps = max(1, int(steps_per_epoch * warmup_epochs))
    rate = 1.0
    def inner(step):
        nonlocal rate
        nonlocal drop_epoch_list
        if step < warmup_steps:  # linear warmup
            return step / warmup_steps
        elif not isinstance(drop_epoch_list, (list, tuple)) or len(drop_epoch_list) == 0:
            return rate
        i = 0
        for e in drop_epoch_list:
            if step > e * steps_per_epoch:
                i += 1
            else:
                break
        rate = rate * (drop_rate ** i)
        drop_epoch_list = drop_epoch_list[i:]
        return rate
    scheduler = lr_scheduler.LambdaLR(optimizer, inner)
    return scheduler
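# Example (matching the 'warmup_drop' case below: steps_per_epoch=100, warmup_epochs=0.5,
# drop_epoch_list=[2, 8], drop_rate=0.1): the multiplier ramps 0 -> 1 over the first 50 steps,
# stays at 1.0 through step 200 (epoch 2), drops to 0.1, then drops again to 0.01 after
# step 800 (epoch 8).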
model = torch.nn.Linear(10, 1)
lr = 1  # 32*4.5e-6; using 1 here so the plotted curve is just the schedule's multiplier
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
max_epoch = 10
steps_per_epoch = 100
total_steps = max_epoch * steps_per_epoch
mode = 'warmup_drop'
current_epoch = 0
match mode:
    case 'cosineAnn':
        steps_per_epoch = 1
        scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=4.5e-6)
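        # T_max=10: the LR follows half a cosine from the initial lr down to eta_min over
        # 10 scheduler steps (one per epoch here, since steps_per_epoch is reset to 1)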
    case 'cosineAnnWarm':
        '''
        Taking T_0=5, T_mult=1 as an example:
        T_0: the epoch at which the learning rate first returns to its initial value.
        T_mult: controls how quickly the restarts recur.
        - If T_mult=1, the LR returns to its maximum (the initial LR) at epochs T_0, 2*T_0, 3*T_0, ..., i*T_0, ...
          e.g. at epochs 5, 10, 15, 20, 25, ...
        - If T_mult>1, the LR returns to its maximum at epochs T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ..., (1+T_mult+...+T_mult**i)*T_0, ...
          e.g. with T_mult=2: at epochs 5, 15, 35, 75, 155, ...
        '''
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=8, T_mult=1)
    case 'cosineTimm':
        steps_per_epoch = 1
        scheduler = timm_scheduler.CosineLRScheduler(optimizer=optimizer, t_initial=max_epoch, lr_min=4.5e-6, warmup_t=1, warmup_lr_init=4.5e-6)
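        # t_initial is the length (in epochs) of one cosine cycle; warmup_t=1 gives one epoch
        # of warmup starting from warmup_lr_init. Note timm schedulers are stepped with
        # scheduler.step(epoch), which the loop below handles via the timm_BaseScheduler check.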
    case 'cosineTorchLambda':
        warmup_epoch = 2
        warmup_factor = 1e-3
        steps_per_epoch = 1
        def f(current_epoch):
            """
            :param current_epoch: epoch (or iteration) index
            :return: a multiplicative LR factor for this step
            Note: before training starts, pytorch appears to call lr_scheduler.step() once up front.
            """
            if current_epoch <= warmup_epoch:
                alpha = float(current_epoch) / warmup_epoch
                # during warmup the LR factor grows linearly from warmup_factor -> 1
                return warmup_factor * (1 - alpha) + alpha
            else:
                # after warmup the LR factor decays from 1 -> 0
                # "poly" policy from DeepLab v2 (Learning rate policy):
                # factor = (1 - progress)^0.9, where progress counts epochs past warmup
                # relative to the remaining (max_epoch - warmup_epoch) epochs
                return (1 - (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch)) ** 0.9
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
    case 'step':
        steps_per_epoch = 1
        scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.5)
    case 'oneCycle':
        scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, epochs=max_epoch, steps_per_epoch=steps_per_epoch, pct_start=0.1, final_div_factor=10)
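        # pct_start=0.1: the first 10% of steps ramp the LR up to max_lr; the remainder anneals
        # it back down (cosine by default). The starting LR is max_lr/div_factor (div_factor
        # defaults to 25) and the final LR is that starting LR divided by final_div_factor.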
    case 'cosineTransformers':
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=steps_per_epoch, num_training_steps=max_epoch * steps_per_epoch)
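        # linear warmup from 0 to the base LR over num_warmup_steps (one epoch's worth here),
        # then cosine decay down to 0 over the remaining training steps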
    case 'WSD':
        scheduler = warmup_stable_decay(optimizer, total_steps)
    case 'warmup_drop':
        scheduler = linear_warmup_drop(optimizer, steps_per_epoch, 0.5, [2, 8], 0.1)
plt.figure()
lr_history = []
for epoch in range(max_epoch):
    for step in range(steps_per_epoch):
        optimizer.step()
        current_lr = optimizer.param_groups[0]['lr']
        if isinstance(scheduler, timm_BaseScheduler):
            scheduler.step(epoch)  # timm schedulers take the epoch index
        else:
            scheduler.step()
        lr_history.append(current_lr)
print(lr_history)
plt.plot(range(len(lr_history)), lr_history)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title(f'LR schedule ({mode})')
plt.show()
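For reference, this is where the scheduler calls usually sit in a real training loop (a sketch reusing the variables above; with per-epoch schedulers such as StepLR or CosineAnnealingLR, scheduler.step() moves out to the epoch loop):

for epoch in range(max_epoch):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        # forward pass and loss.backward() would go here in real training
        optimizer.step()   # update the weights first (recommended order since PyTorch 1.1)
        scheduler.step()   # then advance a per-step schedule such as OneCycleLR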
Author: 心有所向,日复一日,必有精进
Original post: https://www.cnblogs.com/Stareven233/p/17870826.html
License: Creative Commons Attribution-NonCommercial-NoDerivatives 2.5 China Mainland