5-用PyTorch实现线性回归
下面是损失函数
下面是优化器
下面通过model.parameters()可以获得model中所有的参数
点击查看代码
import torch
from torch import device
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = torch.nn.Linear(1, 1) # 权重和偏置
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = LinearModel() # 定义模型
# 定义损失函数
criterion = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 优化器对哪些参数进行更新
for epoch in range(5000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data) # 计算损失
print(epoch, loss.item())
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传播
optimizer.step() # 参数更新
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())
x_test = torch.tensor([4.0])
print('predict(4)=', model(x_test).item())