自建神经网络与结果可视化

本文内容来自莫烦Python(周沫凡)的 TensorFlow 教学视频,主要包括添加神经层以及训练结果的可视化。

YouTube 地址: https://www.youtube.com/watch?v=nhn8B0pM9ls&list=PLXO45tsB95cKI5AIlf5TxxFPzb-0zeVZ8&index=16

 1 import tensorflow as tf
 2 import numpy as np
 3 import os
 4 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 5 import matplotlib.pyplot as plt
 6 def add_layer(inputs, in_size, out_size, activation_function = None):
 7     #矩阵
 8     Weights = tf.Variable(tf.random_normal([in_size,out_size]))
 9     bias = tf.Variable(tf.zeros([1,out_size]) + 0.1)
10     Wx_plus_b = tf.matmul(tf.cast(inputs,tf.float32),Weights) + bias
11     if activation_function is None:
12         outputs = Wx_plus_b
13     else:
14         outputs = activation_function(Wx_plus_b)
15     return outputs
16 
17 if 'session' in locals() and session is not None:
18     print('Close interactive session')
19     session.close()
20 
21 x_data = np.linspace(-1,1,300)[:,np.newaxis]
22 noise  = np.random.normal(0,0.05,x_data.shape)
23 y_data = np.square(x_data) - 0.5 + noise
24 xs = tf.placeholder(tf.float32,[None,1])
25 ys = tf.placeholder(tf.float32,[None,1])
26 
27 l1 = add_layer(x_data,1,10,activation_function=tf.nn.relu)
28 prediction = add_layer(l1,10,1,activation_function=None)
29 
30 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
31 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
32 
33 init = tf.global_variables_initializer()
34 sess = tf.Session()
35 sess.run(init)
36 
37 fig = plt.figure()
38 ax = fig.add_subplot(1,1,1) #连续性,编号
39 ax.scatter(x_data,y_data)
40 plt.ion() #不暂停主程序
41 plt.show()
42 
43 for i in range(1000):
44     sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
45     if i % 50 == 0:
46         try:
47             ax.lines.remove(lines[0])
48         except Exception:
49             pass
50         #print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
51         prediction_value = sess.run(prediction, feed_dict={xs:x_data})
52         lines = ax.plot(x_data,prediction_value,'r-',lw=5)
53         plt.pause(0.1)

 结果显示:

 

需要特别注意的是代码中的以下三行:

1 if 'session' in locals() and session is not None:
2     print('Close interactive session')
3     session.close()

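之所以加入这三行,是因为运行时出现了“InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(300, 1), b.shape=(1, 10), m=300, n=10, k=1”错误。Stack Overflow 上的解答大意是:已有其他会话(可能在另一个进程中)占用了 GPU 显存,导致矩阵乘法(cuBLAS GEMM)无法启动,因此需要先关闭已有的 session。需要注意的是,本文脚本中的会话变量名是 sess,而这三行检查的是名为 session 的变量,所以在本脚本中该判断永远不会成立;若要真正生效,应改为检查 'sess'(例如在 Jupyter 中重复运行时关闭上一次创建的会话)。另一种常见做法是在创建会话时设置 config.gpu_options.allow_growth = True,让 TensorFlow 按需申请显存。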

 

posted @ 2017-09-26 14:23  牧牛子  阅读(719)  评论(0)  编辑  收藏  举报