多元线性回归py实现
多元线性回归实验py实现
采用最简单的平方误差损失函数，思路与一元线性回归类似：每次计算预测值与真实值之间的平方误差 $E=\frac{1}{2}\sum_i\left(z_i-f(x_{1i},x_{2i})\right)^2$，对各参数求偏导数，并沿负梯度方向更新参数（梯度下降）。
from matplotlib import projections
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import axes3d
# Load the training data from data2.csv; the first three columns are used
# as the two features (x, y) and the target (z) respectively.
train = np.loadtxt('data2.csv', delimiter=',', dtype='int')
train_x, train_y, train_z = train[:, 0], train[:, 1], train[:, 2]
# Show the raw (unscaled) samples as a red 3D scatter plot.
ax = plt.subplot(1, 1, 1, projection='3d')
ax.scatter(train_x, train_y, train_z, c='r')
plt.show()
# Randomly initialize the three model parameters, each drawn from [0, 1).
# The (misspelled) names theata0..theata2 are the globals that f() reads.
theata0, theata1, theata2 = np.random.rand(), np.random.rand(), np.random.rand()
# Prediction function: f(x1, x2) = c + a*x1 + b*x2
def f(x, y):
    """Return the plane prediction theata0 + theata1 * x + theata2 * y.

    The parameters are the module-level globals theata0, theata1 and
    theata2, looked up at call time, so every call sees the latest values
    written by the gradient-descent loop. The script calls it both with
    1-D sample arrays and with 2-D meshgrid arrays.
    """
    partial_sum = theata0 + theata1 * x
    return partial_sum + theata2 * y
# Objective function
def E(x, y, z):
    """Return the halved sum of squared errors 0.5 * sum((z - f(x, y)) ** 2).

    x, y are the feature values and z the targets; f is the module-level
    prediction function.
    """
    residual = z - f(x, y)
    return 0.5 * np.sum(residual ** 2)
# Standardization (z-score) helper
def standardize(x):
    """Return x shifted to zero mean and scaled to unit standard deviation.

    Uses ndarray.mean() and ndarray.std() with NumPy's default ddof=0
    (population standard deviation). For a constant input the standard
    deviation is 0, so the result is NaN/inf rather than an exception.
    """
    return (x - x.mean()) / x.std()
# Bring all three columns onto a common scale before gradient descent.
train_x_std, train_y_std, train_z_std = (
    standardize(train_x),
    standardize(train_y),
    standardize(train_z),
)

# Show the standardized samples as a red 3D scatter plot.
ax = plt.subplot(1, 1, 1, projection='3d')
ax.scatter(train_x_std, train_y_std, train_z_std, c='r')
plt.show()
# Batch gradient descent on the standardized data.
ETA = 1e-3   # learning rate
diff = 1     # decrease of the objective in the last step (1 just enters the loop)
count = 0    # number of iterations performed
cnt = []     # iteration indices, for the error curve
errs = []    # objective value after each iteration
# The starting objective is measured against the target column z
# (train_z_std), exactly as inside the loop, so the first `diff` is a real
# decrease of the same objective.
error = E(train_x_std, train_y_std, train_z_std)
log = '第 {} 次 : theta0 = {:.3f}, theta1 = {:.3f},theta2 = {:.3f}, 差值 = {:.4f}'
# Stop once a step lowers the objective by at most 1e-2 (or makes it worse).
while diff > 1e-2:
    # All three partial derivatives of E are evaluated at the current
    # parameters (simultaneous update), so compute the residuals only once.
    residual = f(train_x_std, train_y_std) - train_z_std
    theata0, theata1, theata2 = (
        theata0 - ETA * np.sum(residual),
        theata1 - ETA * np.sum(residual * train_x_std),
        theata2 - ETA * np.sum(residual * train_y_std),
    )
    current_error = E(train_x_std, train_y_std, train_z_std)
    diff = error - current_error
    error = current_error
    cnt.append(count)
    errs.append(current_error)
    count += 1
    print(log.format(count, theata0, theata1, theata2, diff))
# Plot the fitted plane on top of the standardized samples.
ax = plt.subplot(1, 1, 1, projection='3d')
ax.scatter(train_x_std, train_y_std, train_z_std, c='r')
axis_ticks = np.arange(-3, 3, 0.1)  # same grid for both axes: [-3, 3), step 0.1
x, y = np.meshgrid(axis_ticks, axis_ticks)
z = f(x, y)
surf = ax.plot_surface(x, y, z, cmap=cm.Blues, linewidth=1, antialiased=False)
plt.show()

# Plot the objective value against the iteration index.
plt.plot(cnt, errs)
plt.show()
数据集如下（即 data2.csv，每行依次为 x1、x2、z）：
2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
3000,4,539900
1985,4,299900
1534,3,314900
1427,3,198999
1380,3,212000
1494,3,242500
1940,4,239999
2000,3,347000
1890,3,329999
4478,5,699900
1268,3,259900
2300,4,449900
1320,2,299900
1236,3,199900
2609,4,499998
3031,4,599000
1767,3,252900
1888,2,255000
1604,3,242900
1962,4,259900
3890,3,573900
1100,3,249900
1458,3,464500
2526,3,469000
2200,3,475000
2637,3,299900
1839,2,349900
1000,1,169900
2040,4,314900
3137,3,579900
1811,4,285900
1437,3,249900
1239,3,229900
2132,4,345000
4215,4,549000
2162,4,287000
1664,2,368500
2238,3,329900
2567,4,314000
1200,3,299000
852,2,179900
1852,4,299900
1203,3,239500