梯度下降算法
手工实现方式
import numpy as np
import matplotlib.pyplot as plt
# Configure matplotlib so CJK (Chinese) labels render correctly.
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # SimHei supplies the CJK glyphs
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with CJK fonts
def loadDataSet(filename):
    """Load a whitespace-separated data file into feature/target arrays.

    Each non-blank row holds the feature columns followed by the target
    value in the last column.

    Args:
        filename: path to the text data file.

    Returns:
        (X, Y): X is an (m, n) float array of features, Y an (m, 1)
        float array of target values.
    """
    X = []
    Y = []
    # Open in text mode with an explicit encoding instead of reading
    # bytes and decoding each line by hand; skip blank lines.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            eles = list(map(float, line.split()))  # convert every field to float
            X.append(eles[:-1])   # all but the last column are features
            Y.append([eles[-1]])  # the last column is the target value
    return np.array(X), np.array(Y)
def h(theta, X):
    """Hypothesis function: linear prediction X @ theta (X already carries the bias column)."""
    return X @ theta
def J(theta, X, Y):
    """Mean-squared-error cost: sum((X @ theta - Y)^2) / (2m).

    Args:
        theta: (n, 1) parameter vector.
        X: (m, n) design matrix.
        Y: (m, 1) target vector.

    Returns:
        The scalar cost value.
    """
    m = len(X)
    # Compute the residual once (the original evaluated the model twice).
    residual = np.dot(X, theta) - Y
    return np.sum(residual ** 2) / (2 * m)
def bgd(alpha, maxloop, epsilon, X, Y):
    """Batch gradient descent for linear regression.

    Args:
        alpha: learning rate (step size).
        maxloop: maximum number of iterations.
        epsilon: stop as soon as the cost falls below this threshold.
        X: (m, n) design matrix (first column expected to be all ones).
        Y: (m, 1) target vector.

    Returns:
        (theta, costs, thetas): the learned (n, 1) parameter vector, the
        cost after each iteration, and a dict mapping each parameter
        index to its per-iteration value history.
    """
    m, n = X.shape
    # One parameter per feature; the original hard-coded np.zeros((2, 1)),
    # which broke whenever n != 2.
    theta = np.zeros((n, 1))
    costs = []  # cost recorded after every iteration
    thetas = {j: [theta[j, 0]] for j in range(n)}  # per-parameter history
    # range(maxloop) fixes the original off-by-one (`while count <= maxloop`
    # ran maxloop + 1 iterations).
    for _ in range(maxloop):
        # Simultaneous update: theta -= alpha/m * X^T (X theta - Y)
        residual = np.dot(X, theta) - Y
        theta = theta - alpha / m * np.dot(X.T, residual)
        for j in range(n):
            # Record scalars (the original appended 1-element arrays,
            # inconsistent with the scalar initial entries).
            thetas[j].append(theta[j, 0])
        cost = np.sum((np.dot(X, theta) - Y) ** 2) / (2 * m)
        costs.append(cost)
        if cost < epsilon:  # converged: cost dropped below the threshold
            break
    return theta, costs, thetas
# --- Driver: load the data, fit by batch gradient descent, plot the fit. ---
# NOTE(review): the Windows path below uses unescaped backslashes; none of
# them happen to form an escape sequence here, but a raw string r'...'
# would be safer — confirm before reuse.
X, Y = loadDataSet('D:\python_project\HandlePythonExample\day01\data\ex1.txt')
print(X.shape)
print(Y.shape)
m, n = X.shape
X = np.concatenate((np.ones((m, 1)), X), axis=1)  # prepend a column of ones (bias/intercept term)
print(X.shape)
alpha = 0.02  # learning rate
maxloop = 1500  # maximum number of iterations
epsilon = 0.01  # convergence threshold on the cost
result = bgd(alpha, maxloop, epsilon, X, Y)
theta, costs, thetas = result  # theta: learned parameters; costs: per-iteration cost; thetas: parameter history
print(theta)
# The model is now fixed; to predict new instances:
# Y_predict = h(theta, X_predict)
# Predictions on the (column-sorted) training set, used to draw the fitted line:
XCopy = X.copy()
XCopy.sort(0)  # axis=0: sort each column independently
# print(XCopy)
# print(Y)
yHat = h(theta, XCopy)
# print(XCopy[:,1].shape, yHat.shape, theta.shape)
# Plot the regression line over the scatter of training points
# (axis labels: city population in 10k / profit in 10k CNY).
plt.xlabel(u'城市人口(万)')
plt.ylabel(u'利润(万元)')
plt.plot(XCopy[:, 1], yHat, color='r')
plt.scatter(X[:, 1].flatten(), Y.T.flatten())  # scatter plot of the raw data
plt.show()
调用库实现线性回归(scikit-learn 的 LinearRegression 用最小二乘求解析解,并非梯度下降)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
# Configure matplotlib so CJK (Chinese) labels render correctly.
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # SimHei supplies the CJK glyphs
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with CJK fonts
def h(theta, X):
    """Model function: predictions X @ theta for a design matrix X that already includes the bias column."""
    return X @ theta
def loadDataSet(filename):
    """Load a whitespace-separated data file into feature/target arrays.

    Each non-blank row holds the feature columns followed by the target
    value in the last column.

    Args:
        filename: path to the text data file.

    Returns:
        (X, Y): X is an (m, n) float array of features, Y an (m, 1)
        float array of target values.
    """
    X = []
    Y = []
    # Text mode with an explicit encoding replaces the original
    # binary read + manual per-line decode; blank lines are skipped.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            eles = list(map(float, line.split()))  # convert every field to float
            X.append(eles[:-1])   # all but the last column are features
            Y.append([eles[-1]])  # the last column is the target value
    return np.array(X), np.array(Y)
# --- Driver: fit with scikit-learn's LinearRegression and plot the fit. ---
# NOTE(review): the Windows path below uses unescaped backslashes; a raw
# string r'...' would be safer — confirm before reuse.
X, Y = loadDataSet('D:\python_project\HandlePythonExample\day01\data\ex1.txt')
print(X.shape)
print(Y.shape)
reg = LinearRegression().fit(X, Y)
print(reg.get_params())
print(reg.coef_.tolist()[0])  # learned slope(s)
print(reg.intercept_)         # learned intercept
print(reg.predict(np.array([[500]])))  # prediction for a new instance
m, n = X.shape
X = np.concatenate((np.ones((m, 1)), X), axis=1)  # prepend the bias column so h() can be reused
XCopy = X.copy()
XCopy.sort(0)  # axis=0: sort each column independently
# Descriptive names replace the original one-letter a/b and the
# misspelled `theat`; the parameter vector is [[intercept], [slope]]
# to match h()'s [bias, feature] column order.
slope = reg.coef_.tolist()[0][0]
intercept = reg.intercept_.tolist()[0]
theta = np.array([[intercept], [slope]])
yHat = h(theta, XCopy)
# Axis labels: city population in 10k / profit in 10k CNY.
plt.xlabel(u'城市人口(万)')
plt.ylabel(u'利润(万元)')
plt.plot(XCopy[:, 1], yHat, color='r')
plt.scatter(X[:, 1].flatten(), Y.T.flatten())  # scatter plot of the raw data
plt.show()
posted on 2019-07-22 14:55 Indian_Mysore 阅读(150) 评论(0) 编辑 收藏 举报