一点飞鸿影下

孤村落日残霞,轻烟老树寒鸦,一点飞鸿影下。 青山绿水,白草红叶黄花。

导航

梯度下降解决线性回归问题(带超详细注释)

学习笔记


# -*- coding=utf-8 -*-


# 梯度下降方法快速解决线性回归问题

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# -----------------------------------------------模拟准备直线 y=0.1 * x + b 附近100个点的坐标---------------------------
# 构建数据
points_num = 100
vectors = []
# 使用numpy 正态随机函数 随机初始化100个点
# 点的坐标(x,y)值对应一个线性方程  y = 0.1 * x + 0.2
# 权重(Weight  0.1  偏差(Bias) 0.2
for i in range(points_num):
    # np.random.normal(mean,stdev,size) 返回均值为mean,标准差为stdev,长度为size的列表
    x1 = np.random.normal(0.0,0.66) 
    y1 = 0.1 * x1 + 0.5 + np.random.normal(0.0,0.04)
    vectors.append([x1,y1])


# -----------------------------------------------使用plt画出100个点的位置-----------------------------------------------
# 获取100个点 x,y坐标的集合
x_data = [v[0] for v in vectors] # 真实的点的 x坐标
y_data = [v[1] for v in vectors] # 真实的点的 y坐标

# 图像1 :展示100个随机点
# plot示例: https://matplotlib.org/gallery.html
plt.plot(x_data,y_data, 'r*', label='Original data') # 红色星型点
plt.title("Linear Regression using Gradient Descent")
plt.savefig('1.png')
plt.show()


#----------------------------------------------------------构建模型-----------------------------------------------------
# 构建线性回归模型
W = tf.Variable(tf.random_uniform([1], -1.0,1.0)) # 初始化权重
b = tf.Variable(tf.zeros([1]))# 初始化Bias
y = W * x_data + b  # 模型计算出来的y

# 定义损失函数 loss function(损失函数)  或  cost function(代价函数)
# 对 Tensor 所有维度计算(y - y_data)^2 之和 / n
loss = tf.reduce_mean(tf.square(y - y_data))

# 用梯度下降优化器来优化 loss function() 
optimizer = tf.train.GradientDescentOptimizer(0.5) #设置学习率
train = optimizer.minimize(loss)

# 创建会话
sess = tf.Session()
# 初始化数据流图中的所有变量
init = tf.global_variables_initializer()
sess.run(init)

#----------------------------------------------------------开始训练---------------------------------------------------------
# 训练  20 步
for step in range(20):
    # 优化每一步
    sess.run(train)
    # 打印出每一步的损失,权重,偏差
    print(step,sess.run(loss),sess.run(b))


#----------------------------------------------------------画图展示---------------------------------------------------------
# 绘制所有点 会指出最佳拟合的直线
plt.plot(x_data,y_data, 'r*', label='Original data') # 红色星型点
plt.title("Linear Regression using Gradient Descent")
plt.plot(x_data,sess.run(W)*x_data+sess.run(b), label="line")
plt.xlabel('X')
plt.ylabel('y')
plt.savefig('1.png')
sess.close()

posted on 2019-09-26 01:14  一点飞鸿影下  阅读(380)  评论(0编辑  收藏  举报