多项式回归
多项式回归
import torch
import numpy
def make_features(x):
'''获取 [x, x^2, x^3]...的矩阵'''
x = x.unsqueeze(1) #将一维数据变为(n,1)二维矩阵形式
return torch.cat([x ** i for i in range(1, 4)], 1) #按列拼接
def f(x):
W_target = torch.Tensor([0.5, 3., 2.4]).unsqueeze(1)
b_target = torch.Tensor([0.9])
return x.mm(W_target) + b_target # 表达式:f(x) = X * W_target + b_target
batch_size=32
random = torch.randn(batch_size)
def get_batch(batch_size=32):
''' 获取32个数据对:(x, f(x)) '''
# random = torch.randn(batch_size)
x = make_features(random)
y = f(x)
return torch.autograd.Variable(x), torch.autograd.Variable(y)
class poly_model(torch.nn.Module):
''' 定义多项式模型 '''
def __init__(self):
super(poly_model, self).__init__()
self.poly = torch.nn.Linear(3,1) #输入3维[x, x^2, x^3],输出1维y
def forward(self, x):
out = self.poly(x)
return out
model = poly_model()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
epoch = 0
#get data
batch_x, batch_y = get_batch()
while True:
#forward
out = model(batch_x)
loss = criterion(out, batch_y)
print_loss = loss.data
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch += 1
if print_loss < 1e-2:
print('epoch:',epoch)
break
print(list(model.parameters())) #打印最后学习到的参数w, b
学习结果如下:
epoch: 1179
[Parameter containing:
tensor([[0.5059, 3.0444, 2.3960]], requires_grad=True), Parameter containing:
tensor([0.7748], requires_grad=True)]
绘制曲线:
import matplotlib.pyplot as plt
model.eval()
predict = model(torch.autograd.Variable(batch_x))
predict = predict.data.numpy()
plt.plot(sorted(random), sorted(batch_y.numpy()), 'ro', label='real curve')
plt.plot(sorted(random), sorted(predict.flatten()), label= 'Fitting curve')
plt.legend()
plt.show()
本文来自博客园,作者:aJream,转载请记得标明出处:https://www.cnblogs.com/ajream/p/15383549.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人