TensorFlow之梯度下降解决线性回归(7)

Operations(操作)

 

操作类别 操作举例
基本操作 split,rank,reshape,random_shuffle,slice,concat,...
逐元素的数学操作 add,subtract,multiply,div,greater,less,equal,exp,log,...
矩阵操作 eye,matmul,matrix_inverse,matrix_determinant,...
状态型操作 Variable,assign,assign_add,...
神经网络操作 sigmoid,relu,softmax,max_pool,dropout,conv2d,dynamic_rnn,...
保存/还原操作 train.saver.save,train.saver.restore
模型训练操作 train.GradientDescentOptimizer,train.AdamOptimizer,...

 

一些等价的操作

操作 等价操作
tf.add(a, b) a+b
tf.subtract(a, b) a-b
tf.multiply(a, b) a*b
tf.div(a, b) a/b
tf.mod(a,b) a%b
tf.square(a) a*a

 

 1 # -*- coding:utf-8 -*-
 2 
 3 """
 4 用梯度下降的优化方法来快速解决线性回归问题
 5 """
 6 
 7 import numpy as np
 8 import matplotlib.pyplot as plt
 9 import tensorflow as tf
10 
11 #构建数据
12 points_num = 100
13 vectors = []
14 #用 Numpy 的正态随机分布函数生成 100 个点
15 #这些点的 (x, y) 坐标值对应线性方程 y = 0.1 * x + 0.2
16 #权重 (Weight) 0.1, 偏差 (Bias) 0.2
17 for i in xrange(points_num):
18     x1 = np.random.normal(0.0, 0.66)
19     y1 = 0.1 * x1 + 0.2 + np.random.normal(0.0, 0.04)
20     vectors.append([x1, y1])
21 
22 x_data = [v[0] for v in vectors] #真实的点的 x 的坐标
23 y_data = [v[1] for v in vectors] #真实的点的 y 的坐标
24 
25 #图像 1: 展示100个随机数据点
26 plt.plot(x_data, y_data, "*", label="Original data") #红色星形的点
27 plt.title("Linear Regression using Gradient Descent")
28 plt.legend()
29 plt.show()
30 
31 
32 #构建线性回归模型
33 W = tf.Variable(tf.random_shuffle([1], -1.0, 1.0))  #初始化 Weight
34 b = tf.Variable(tf.zeros([1]))   #初始化 Bias
35 y = W * x_data + b               #模型计算出来的 y
36 
37 #定义损失函数 loss function 或 cost function(代价函数)
38 #对 Tensor 的所有维度计算 ( y - y_data) ^ 2 之和 / N
39 loss = tf.reduce_mean(tf.square(y-y_data))
40 
41 #用梯度下降的优化器来优化我们的 loss function
42 optimizer = tf.train.GradientDescentOptimizer(0.5)  #设置学习率 0.5
43 train = optimizer.minimize(loss)
44 
45 # 创建会话
46 sess = tf.Session()
47 
48 #初始化数据流图中的所有变量
49 init = tf.global_variables_initializer()
50 sess.run(init)
51 
52 #训练 20 步
53 for step in xrange(20):
54     # 优化每一步
55     sess.run(train)
56     #打印出每一步的损失,权重和偏差
57     print ("Step=%d, Loss=%f, [Weight=%f Bias=%f]") % (step, sess.run(loss), sess.run(W), sess.run(b))
58     
59 #图像 2 : 绘制所有的点并且绘制出最佳拟合的直线
60 plt.plot(x_data, y_data, "r*", label="Original data") #红色星形的点
61 plt.title("Linear Regression using Gradient Descent")
62 ply.plot(x_data, sess.run(W) * x_data + sess.run(b), label="Fitted line") #拟合的线
63 plt.legend()
64 plt.xlabel('x')
65 plt.ylabel('y')
66 plt.show()
67 
68 #关闭会话
69 sess.close()

 

posted on 2018-10-11 11:52  qiuqiu365  阅读(476)  评论(0编辑  收藏  举报