TensorFlow之梯度下降解决线性回归(7)
Operations(操作)
操作类别 | 操作举例 |
基本操作 | split,rank,reshape,random_shuffle,slice,concat,... |
逐元素的数学操作 | add,subtract,multiply,div,greater,less,equal,exp,log,... |
矩阵操作 | eye,matmul,matrix_inverse,matrix_determinant,... |
状态型操作 | Variable,assign,assign_add,... |
神经网络操作 | sigmoid,relu,softmax,max_pool,dropout,conv2d,dynamic_rnn,... |
保存/还原操作 | train.saver.save,train.saver.restore |
模型训练操作 | train.GradientDescentOptimizer,train.AdamOptimizer,... |
一些等价的操作
操作 | 等价操作 |
tf.add(a, b) | a+b |
tf.subtract(a, b) | a-b |
tf.multiply(a, b) | a*b |
tf.div(a, b) | a/b |
tf.mod(a,b) | a%b |
tf.square(a) | a*a |
1 # -*- coding:utf-8 -*- 2 3 """ 4 用梯度下降的优化方法来快速解决线性回归问题 5 """ 6 7 import numpy as np 8 import matplotlib.pyplot as plt 9 import tensorflow as tf 10 11 #构建数据 12 points_num = 100 13 vectors = [] 14 #用 Numpy 的正态随机分布函数生成 100 个点 15 #这些点的 (x, y) 坐标值对应线性方程 y = 0.1 * x + 0.2 16 #权重 (Weight) 0.1, 偏差 (Bias) 0.2 17 for i in xrange(points_num): 18 x1 = np.random.normal(0.0, 0.66) 19 y1 = 0.1 * x1 + 0.2 + np.random.normal(0.0, 0.04) 20 vectors.append([x1, y1]) 21 22 x_data = [v[0] for v in vectors] #真实的点的 x 的坐标 23 y_data = [v[1] for v in vectors] #真实的点的 y 的坐标 24 25 #图像 1: 展示100个随机数据点 26 plt.plot(x_data, y_data, "*", label="Original data") #红色星形的点 27 plt.title("Linear Regression using Gradient Descent") 28 plt.legend() 29 plt.show() 30 31 32 #构建线性回归模型 33 W = tf.Variable(tf.random_shuffle([1], -1.0, 1.0)) #初始化 Weight 34 b = tf.Variable(tf.zeros([1])) #初始化 Bias 35 y = W * x_data + b #模型计算出来的 y 36 37 #定义损失函数 loss function 或 cost function(代价函数) 38 #对 Tensor 的所有维度计算 ( y - y_data) ^ 2 之和 / N 39 loss = tf.reduce_mean(tf.square(y-y_data)) 40 41 #用梯度下降的优化器来优化我们的 loss function 42 optimizer = tf.train.GradientDescentOptimizer(0.5) #设置学习率 0.5 43 train = optimizer.minimize(loss) 44 45 # 创建会话 46 sess = tf.Session() 47 48 #初始化数据流图中的所有变量 49 init = tf.global_variables_initializer() 50 sess.run(init) 51 52 #训练 20 步 53 for step in xrange(20): 54 # 优化每一步 55 sess.run(train) 56 #打印出每一步的损失,权重和偏差 57 print ("Step=%d, Loss=%f, [Weight=%f Bias=%f]") % (step, sess.run(loss), sess.run(W), sess.run(b)) 58 59 #图像 2 : 绘制所有的点并且绘制出最佳拟合的直线 60 plt.plot(x_data, y_data, "r*", label="Original data") #红色星形的点 61 plt.title("Linear Regression using Gradient Descent") 62 ply.plot(x_data, sess.run(W) * x_data + sess.run(b), label="Fitted line") #拟合的线 63 plt.legend() 64 plt.xlabel('x') 65 plt.ylabel('y') 66 plt.show() 67 68 #关闭会话 69 sess.close()