optim.SGD

链接:https://www.zhihu.com/question/585468191/answer/2905219147

optim.SGD是PyTorch中的一个优化器,其实现了随机梯度下降(Stochastic Gradient Descent,SGD)算法。在深度学习中,我们通常使用优化器来更新神经网络中的参数,以使得损失函数尽可能地小。

在PyTorch中使用optim.SGD优化器,一般需要指定以下参数:

  • params:需要更新的参数,通常为模型中的权重和偏置项。
  • lr:学习率,即每次参数更新时的步长。
  • momentum:动量,用来加速模型收敛速度,避免模型陷入局部最优解。
  • dampening:动量衰减,用来控制动量的衰减速度。
  • weight_decay:权重衰减,用来防止模型过拟合,即通过对权重的L2正则化来约束模型的复杂度。
  • nesterov:是否使用Nesterov动量。

在优化过程中,optim.SGD会根据当前的梯度和学习率计算出每个参数的更新量,并更新模型的参数。更新量的计算公式如下:

v(t) = momentum * v(t-1) - lr * (grad + weight_decay * w(t))
 w(t) = w(t-1) + v(t)

其中,v(t)表示当前时刻的速度,v(t-1)表示上一个时刻的速度,grad表示当前时刻的梯度,w(t)表示当前时刻的权重,w(t-1)表示上一个时刻的权重。

optim.SGD算法中的动量(momentum)可以看作是一个惯性项,用来在参数更新时保留之前的状态。当梯度方向发生改变时,动量能够加速模型收敛,并降低震荡。Nesterov动量可以在动量的基础上进一步优化模型的性能,它会先根据上一个时刻的速度来计算下一个时刻的梯度,然后再更新参数。

需要注意的是,在使用optim.SGD时,要适当调整学习率和动量等超参数,以便在训练中达到更好的性能。

e.g:

# 神经网络框架搭建
    net = InceptionBlock(num_classes=2, in_channels=1, init_weights=True)
    net.to(device)
    # 损失函数
    loss_function = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.SGD(net.parameters(), lr=0.003, momentum=0.9)

    epochs = 50  # 迭代次数
    best_acc = 0.0  # 精度

    # 网络结构保存路径
    save_path = '../model/signal_classes.pth'

    train_steps = len(train_loader)
    for epoch in range(epochs):
        # 开启训练模式
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            signal_series, label_series = data
            signal_series = signal_series.reshape(batch_size, 1, 1200)
            # 优化器归零
            optimizer.zero_grad()
            # 模型计算
            model_predict = net(signal_series.to(device))
            # 计算损失函数
            loss = loss_function(model_predict, f.one_hot(label_series - 1, num_classes=2).float())
            # 反向传播计算
            loss.backward()
            # 优化器迭代
            optimizer.step()
            # 损失计算
            running_loss += loss.item()
            # 进度条计算
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

 

posted @ 2023-08-15 14:48  wangssd  阅读(132)  评论(0编辑  收藏  举报