第六节 实现简单的线性回归
import tensorflow as tf import os # 定义命令行参数,"max_step", 100, "模型训练的步数", 三个参数都是必须的,max_step在程序中引用的变量名,100是给第一个参数设置的默认值,第三个参数是第一个参数的参数说明 tf.app.flags.DEFINE_integer("max_step", 100, "模型训练的步数") # 定义获取命令行参数的名字,在程序中调用aaa.max_step aaa = tf.app.flags.FLAGS def myregression(): """实现一个线性回归""" with tf.variable_scope('data'): # 定义变量作用域,使代码结构更清晰,而且在TensorBoard可视化中显示更清晰 # 1.构造数据,x 特征值 [100, 1] y 目标值 [100] x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name='x_data') # 矩阵相乘必须是二维的 y_true = tf.matmul(x, [[0.7]]) + 0.8 with tf.variable_scope('model'): # 2.建立线性回归模型:1个权重,一个偏置 # 随机初始化一个权重和偏置的值,计算损失,然后通过梯度下降不断寻找最小损失 # 权重和偏置必须使用变量定义,因为它们的值是需要不断改变的,trainable参数指定是否随梯度下降进行优化,默认true weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0, name='w'), trainable=True) bias = tf.Variable(0.0, name='b') y_predict = tf.matmul(x, weight) + bias with tf.variable_scope('loss'): # 3.建立损失函数,square求平方,reduce_mean求平均值 loss = tf.reduce_mean(tf.square(y_true-y_predict)) with tf.variable_scope('optimizer'): # 4.梯度下降优化损失,0.1是学习率,minimize最小化损失 train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 收集tensor,losser是在TensorBoard后台显示的名字 tf.summary.scalar("losser", loss) tf.summary.histogram("w", weight) # 定义合并tensor的op,在sess中方便将其添加进事件中 merge = tf.summary.merge_all() # 定义对变量进行初始化的op init_op = tf.global_variables_initializer() # 定义一个保存模型的实例op saver = tf.train.Saver() # 通过会话运行程序 with tf.Session() as sess: # 初始化变量 sess.run(init_op) # 打印最先随机初始化的权重和偏置 print("随机初始化的参数权重:{},偏置:{}".format(weight.eval(), bias.eval())) # 建立事件文件 filewriter = tf.summary.FileWriter("./tmp/summary/test", graph=sess.graph) # 加载模型,覆盖模型当中的一开始随机初始化的参数,让模型接着从上次被打断的地方的参数继续进行 if os.path.exists("./tmp/ckpt/model/checkpoint"): saver.restore(sess, "./tmp/ckpt/model") # 循环运行优化 for i in range(aaa.max_step): sess.run(train_op) # 运行合并的merge op summ = sess.run(merge) # 将summ添加入事件中 filewriter.add_summary(summ, i) print("第{}次优化的参数权重:{},偏置:{}".format(i, weight.eval(), bias.eval())) # 保存模型,model保存模型的名字,一定要有 saver.save(sess, "./tmp/ckpt/model") if __name__ == "__main__": myregression()