pytorch中学习率的调整方法

一、手动法

二、利用lr_scheduler()提供的集中衰减函数

2.1 利用lr_lambda函数

具体使用:

from torch.optim import SGD, lr_scheduler
import matplotlib.pyplot as plt
from torch.nn import Module, Sequential, Linear, CrossEntropyLoss


# 定义网络模型
class model(Module):
def __init__(self):
super(model, self).__init__()
self.fc = Sequential(
Linear(1,10)
)

def forward(self, input):
output = self.fc(input)
return output

# 初始化网络模型
Model = model()
# 定义损失函数
Loss = CrossEntropyLoss()
# 创建优化器
lr = 0.01
optimizer = SGD(Model.parameters(), lr=lr)
# 定义一个list保存学习率
lr_list = []

# 定义学习率与轮数关系的函数
lambda1 = lambda epoch:0.95 ** epoch # 学习率 = 0.95**(轮数)
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)

for epoch in range(100):
print("epoch={}, lr={}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
scheduler.step()
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])

plt.plot(range(100),lr_list,color = 'r',label = 'LambdaLR')
plt.ylabel('learning rate')
plt.xlabel('epoch')
plt.legend()
plt.show()

posted on   陈酉西  阅读(52)  评论(0编辑  收藏  举报

相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

统计

点击右上角即可分享
微信分享提示