有事没事领个红包

练习一,线性函数模型建立

实例:使用法随机梯度下降法建立线性函数y=3*x+6

#coding=utf-8
from __future__ import print_function
import os
import tensorflow as tf

from matplotlib import pyplot as plt
import numpy as np



#create data start
x_data = np.random.rand(100).astype(dtype=np.float32)
y_data = x_data * 3 + 6
#create data end

#create tensorflow structure start
Weights = tf.Variable(tf.random_uniform([1],-5,5))
Biases = tf.Variable(tf.ones([1]))

y = Weights * x_data + Biases

loss = tf.reduce_mean(tf.square(y-y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss=loss)

init = tf.initialize_all_variables()
#create tensorflow structure end


#start training
sess = tf.Session()
sess.run(init)

print("before training data is")
print(sess.run(Weights), sess.run(Biases),"\n")
for step in np.arange(300):
    if step % 20 == 0 :
        print(sess.run(Weights), sess.run(Biases))
    sess.run(train)

print("\nafter training data is")
print(sess.run(Weights), sess.run(Biases))

sess.close()

显示结果如下

before training data is
[-0.52837467] [ 1.] 

[-0.52837467] [ 1.]
[ 2.74584365] [ 6.12300587]
[ 2.92807698] [ 6.03480911]
[ 2.97964644] [ 6.0098505]
[ 2.99424028] [ 6.00278759]
[ 2.99836993] [ 6.00078869]
[ 2.99953914] [ 6.00022316]
[ 2.99986959] [ 6.00006294]
[ 2.99996328] [ 6.00001764]
[ 2.99998927] [ 6.00000525]
[ 2.9999969] [ 6.00000143]
[ 2.99999809] [ 6.00000095]
[ 2.99999809] [ 6.00000095]
[ 2.99999809] [ 6.00000095]
[ 2.99999809] [ 6.00000095]

after training data is
[ 2.99999809] [ 6.00000095]

 

 

如果想要将中间的变量结果保存下来,可以使用方法如下

storeFileName = "/tmp/modelvariable.val"

saver = tf.train.Saver()
saver.save(sess,storeFileName)

在下一次恢复时,就不需要初始化变量了,可以直接定义好变量后,使用恢复函数就可以将之前的变量参数恢复出来。具体如下

#coding=utf-8

import tensorflow as tf
import numpy as np

#restore variable from file
Weights = tf.Variable(tf.random_uniform([1],-5,5))
Biases = tf.Variable(tf.ones([1]))

storeFileName = "/tmp/modelvariable.val"

saver = tf.train.Saver()

sess = tf.Session()
saver.restore(sess,storeFileName)


print "already restore data from file"
print sess.run(Weights),sess.run(Biases)


sess.close()

 

posted @ 2016-11-11 10:12  crazyCodeLove  阅读(372)  评论(0编辑  收藏  举报