练习一,线性函数模型建立
实例:使用法随机梯度下降法建立线性函数y=3*x+6
#coding=utf-8 from __future__ import print_function import os import tensorflow as tf from matplotlib import pyplot as plt import numpy as np #create data start x_data = np.random.rand(100).astype(dtype=np.float32) y_data = x_data * 3 + 6 #create data end #create tensorflow structure start Weights = tf.Variable(tf.random_uniform([1],-5,5)) Biases = tf.Variable(tf.ones([1])) y = Weights * x_data + Biases loss = tf.reduce_mean(tf.square(y-y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss=loss) init = tf.initialize_all_variables() #create tensorflow structure end #start training sess = tf.Session() sess.run(init) print("before training data is") print(sess.run(Weights), sess.run(Biases),"\n") for step in np.arange(300): if step % 20 == 0 : print(sess.run(Weights), sess.run(Biases)) sess.run(train) print("\nafter training data is") print(sess.run(Weights), sess.run(Biases)) sess.close()
显示结果如下
before training data is [-0.52837467] [ 1.] [-0.52837467] [ 1.] [ 2.74584365] [ 6.12300587] [ 2.92807698] [ 6.03480911] [ 2.97964644] [ 6.0098505] [ 2.99424028] [ 6.00278759] [ 2.99836993] [ 6.00078869] [ 2.99953914] [ 6.00022316] [ 2.99986959] [ 6.00006294] [ 2.99996328] [ 6.00001764] [ 2.99998927] [ 6.00000525] [ 2.9999969] [ 6.00000143] [ 2.99999809] [ 6.00000095] [ 2.99999809] [ 6.00000095] [ 2.99999809] [ 6.00000095] [ 2.99999809] [ 6.00000095] after training data is [ 2.99999809] [ 6.00000095]
如果想要将中间的变量结果保存下来,可以使用方法如下
storeFileName = "/tmp/modelvariable.val" saver = tf.train.Saver() saver.save(sess,storeFileName)
在下一次恢复时,就不需要初始化变量了,可以直接定义好变量后,使用恢复函数就可以将之前的变量参数恢复出来。具体如下
#coding=utf-8 import tensorflow as tf import numpy as np #restore variable from file Weights = tf.Variable(tf.random_uniform([1],-5,5)) Biases = tf.Variable(tf.ones([1])) storeFileName = "/tmp/modelvariable.val" saver = tf.train.Saver() sess = tf.Session() saver.restore(sess,storeFileName) print "already restore data from file" print sess.run(Weights),sess.run(Biases) sess.close()
学习过程中,难免出错。如果您在阅读过程中遇到不太明白,或者有疑问。欢迎指正...联系邮箱crazyCodeLove@163.com
如果觉得有用,想赞助一下请移步赞助页面:赞助一下