TensorFlow 2.0 -- Learning Rate Scheduling (Linear Warmup, Exponential Decay)

The learning rate ramps up linearly from 0 to lr_max over the first warm_step training steps, then decays exponentially (by a factor of 0.999 per step) from the optimizer's initial rate, floored at lr_min. This is implemented as a Keras callback that recomputes the rate at the end of every batch:

from tensorflow import keras
from tensorflow.keras import backend as K


class WarmUpLineDecayScheduler(keras.callbacks.Callback):
    def __init__(self, lr_max, lr_min, warm_step, sum_step, bat):
        super(WarmUpLineDecayScheduler, self).__init__()
        self.lr_max = lr_max        # peak rate reached at the end of warmup
        self.lr_min = lr_min        # floor for the decayed rate
        self.warm_step = warm_step  # number of warmup steps
        self.sum_step = sum_step    # steps (batches) per epoch
        self.bat = bat              # batch size

    def on_train_begin(self, logs=None):
        # starting point for the exponential decay
        self.init_lr = K.get_value(self.model.optimizer.lr)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_batch_end(self, batch, logs=None):
        step = self.epoch * self.sum_step + batch  # global step counter
        learning_decay_steps = 1
        learning_decay_rate = 0.999
        # linear warmup: ramp from 0 up to lr_max over warm_step steps
        warm_lr = self.lr_max * (step / self.warm_step)
        # exponential decay from the initial rate, floored at lr_min
        decay_lr = max(self.init_lr * learning_decay_rate ** (step / learning_decay_steps),
                       self.lr_min)
        lr = warm_lr if step < self.warm_step else decay_lr
        K.set_value(self.model.optimizer.lr, lr)
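
For comparison, the same warmup-plus-decay curve can also be written as a tf.keras.optimizers.schedules.LearningRateSchedule and handed directly to the optimizer, so no callback has to mutate the rate mid-training. A minimal sketch, not from the original post; the class name WarmUpExpDecay, the init_lr argument, and the 0.999 default are assumptions mirroring the callback above:

import tensorflow as tf

class WarmUpExpDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    # hypothetical schedule class; constants mirror the callback above
    def __init__(self, lr_max, lr_min, init_lr, warm_step, decay_rate=0.999):
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.init_lr = init_lr
        self.warm_step = warm_step
        self.decay_rate = decay_rate

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # linear warmup branch
        warm_lr = self.lr_max * (step / self.warm_step)
        # exponential decay branch, floored at lr_min
        decay_lr = tf.maximum(self.init_lr * self.decay_rate ** step, self.lr_min)
        return tf.where(step < self.warm_step, warm_lr, decay_lr)

# usage: optimizer = keras.optimizers.Adam(learning_rate=WarmUpExpDecay(1e-3, 1e-5, 1e-3, warm_step=500))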

# lr_rate, lr_min, warm_epoch, bat, train_x, train_db, test_db, epochs and s_model come from the surrounding training script
steps_per_epoch = int(train_x.shape[0] / bat)  # cast so the global step stays an integer
warm_up = WarmUpLineDecayScheduler(lr_rate, lr_min, warm_step=warm_epoch * steps_per_epoch,
                                   sum_step=steps_per_epoch, bat=bat)

s_model.fit(train_db, epochs=epochs, validation_data=test_db, callbacks=[warm_up])
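
To sanity-check the curve before training, the per-step rates can be simulated offline with plain Python; the hyperparameter values below are illustrative assumptions, not from the post:

# simulate the schedule outside of training (illustrative values)
lr_max, lr_min, init_lr = 1e-3, 1e-5, 1e-3
warm_step, total_steps = 500, 5000
lrs = []
for step in range(total_steps):
    if step < warm_step:
        lr = lr_max * step / warm_step               # linear warmup
    else:
        lr = max(init_lr * 0.999 ** step, lr_min)    # exponential decay with floor
    lrs.append(lr)
print(lrs[0], lrs[warm_step - 1], lrs[-1])  # 0.0, just under lr_max, floored at lr_min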
