线回与非线回---sklearn--多元线性回归
前言:
前面用自写函数解决了多元问题,现在用sklearn库来解决多元线性问题
正文:
#老朋友,不介绍了
import numpy as np
from numpy import genfromtxt
#把线性回归模型库单独导出来
from sklearn import linear_model
#把画图工具库导出来
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#读入你的数据
data = genfromtxt(r"Delivery.csv",delimiter = ',')
print(data)
数据图片:
#切割数据,和原先一样的方法
#需要什么就切什么
x_data = data[:,:-1]
y_data = data[:,-1]
print(x_data)
print(y_data)
切分后的数据:
#创建回归模型
#带入切分好的数据
model = linear_model.LinearRegression()
model.fit(x_data,y_data)
#系数
print("model.coef:",model.coef_)
#截距
print("model.intercept:",model.intercept_)
#测试一下
x_test = [[102,4]]
predict = model.predict(x_test)
print("predict:",predict)
测试结果如下:
#使用add_subplot函数来创建3d面
ax = plt.figure().add_subplot(111,projection = '3d')
#描点并设置参数(函数介绍在上一篇内容中)
ax.scatter(x_data[:,0],x_data[:,1],y_data,c='r',marker = 'o',s=100)
x0 = x_data[:,0]
x1 = x_data[:,1]
#使用meshgrid函数生成网格矩阵
x0,x1 = np.meshgrid(x0,x1)
z = model.intercept_ + x0*model.coef_[0]+x1*model.coef_[1]
#画3d图
ax.plot_surface(x0,x1,z)
#设置坐标轴名称
ax.set_xlabel('miles')
ax.set_ylabel('num of delivers')
ax.set_zlabel('time')
#显示图像
plt.show()
总结:
效果还是不错的,而且这个图可以移动,方便查看!