Loading

Logistic 回归测试代码

简单概念

Logistic 回归是一种经典的分类方法,多用于二分类的问题。通过寻找合适的分类函数,用以对输入的数据进行预测,并给出判断结果。使用 sigmoid 函数(逻辑函数)将线性模型的结果压缩到 [0, 1] 之间,使输出的结果具有概率意义,实现输入值到输出概率的转换。

sigmoid 函数:$ g(z) = \frac{1}{1+e^{-z}} $

img

测试代码

目标:大于 3 的数输出结果为 1,小于等于 3 的数输出结果为 0。

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # cuda 加速

# 测试数据,≤3 结果为 0,>3 结果为 1
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y_data = torch.Tensor([[0], [0], [0], [1], [1], [1]])


# Logistic 回归测试模型
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # 线性层

    def forward(self, x):
        output = torch.nn.functional.sigmoid(self.linear(x))
        return output


model = LogisticRegressionModel().to(device)

# 损失函数和优化器
criterion = torch.nn.BCELoss(reduction='sum').to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练
for epoch in range(1, 40001):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    if epoch % 1000 == 0:
        print(f'epoch: {epoch}, loss: {loss.item():.4f}')

    optimizer.zero_grad()  # 梯度清零
    loss.backward()        # 反向传播
    optimizer.step()       # 利用优化器对参数 x 进行更新

# 预测结果
test_data = [-2.0, -1.0, 0.0, 7.0, 8.0, 9.0]
for data in test_data:
    x_test = torch.Tensor([[data]])
    y_test = model(x_test)
    print(f'{x_test.item()} pred: {y_test.item():.8f}')

输出结果

epoch: 1000, loss: 1.2079
epoch: 2000, loss: 0.8453
epoch: 3000, loss: 0.6859
epoch: 4000, loss: 0.5891
epoch: 5000, loss: 0.5216
epoch: 6000, loss: 0.4706
epoch: 7000, loss: 0.4302
epoch: 8000, loss: 0.3970
epoch: 9000, loss: 0.3691
epoch: 10000, loss: 0.3452
epoch: 11000, loss: 0.3245
epoch: 12000, loss: 0.3062
epoch: 13000, loss: 0.2900
epoch: 14000, loss: 0.2755
epoch: 15000, loss: 0.2624
epoch: 16000, loss: 0.2506
epoch: 17000, loss: 0.2397
epoch: 18000, loss: 0.2298
epoch: 19000, loss: 0.2207
epoch: 20000, loss: 0.2123
epoch: 21000, loss: 0.2046
epoch: 22000, loss: 0.1973
epoch: 23000, loss: 0.1906
epoch: 24000, loss: 0.1843
epoch: 25000, loss: 0.1784
epoch: 26000, loss: 0.1729
epoch: 27000, loss: 0.1677
epoch: 28000, loss: 0.1628
epoch: 29000, loss: 0.1582
epoch: 30000, loss: 0.1538
epoch: 31000, loss: 0.1497
epoch: 32000, loss: 0.1458
epoch: 33000, loss: 0.1421
epoch: 34000, loss: 0.1385
epoch: 35000, loss: 0.1352
epoch: 36000, loss: 0.1319
epoch: 37000, loss: 0.1289
epoch: 38000, loss: 0.1260
epoch: 39000, loss: 0.1232
epoch: 40000, loss: 0.1205
-2.0 pred: 0.00000000
-1.0 pred: 0.00000000
0.0 pred: 0.00000000
7.0 pred: 1.00000000
8.0 pred: 1.00000000
9.0 pred: 1.00000000
posted @ 2023-05-03 16:24  滑稽果  阅读(50)  评论(0编辑  收藏  举报
浏览器标题切换
浏览器标题切换end