Following the previous post, MNIST handwritten digit training and recognition (1), which set up the basic neural network for MNIST, this post covers how to save that network and how to use the saved network for recognition. Straight to the code.

Saving the neural network

import tensorflow as tf
import numpy as np

# TensorFlow ships with a loader for the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data

# Download the MNIST dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# The labels are the digits 0-9, so 10 output nodes are enough to encode them;
# one_hot means only one output node is "hot":
#  1 -> [0,1,0,0,0,0,0,0,0,0]
#  2 -> [0,0,1,0,0,0,0,0,0,0]
#  5 -> [0,0,0,0,0,1,0,0,0,0]
# /tmp is the macOS temp directory and is wiped on reboot; Linux also uses /tmp

# Define how many neurons each layer has
n_input_layer = 28 * 28  # input layer

n_layer_1 = 500   # hidden layer
n_layer_2 = 1000  # hidden layer
n_layer_3 = 300   # hidden layer - sounds mysterious, but it is simply any layer between input and output

n_output_layer = 10  # output layer
"""
Choosing the number of layers: use 1 layer for linear data, 2 for non-linear data,
3+ for highly non-linear data. Too many layers/neurons leads to overfitting.
"""


# Define the network to be trained (feedforward)
def nn(input_layer):
    layer_1 = tf.layers.dense(input_layer, n_layer_1, tf.nn.relu)
    layer_2 = tf.layers.dense(layer_1, n_layer_2, tf.nn.relu)
    layer_3 = tf.layers.dense(layer_2, n_layer_3, tf.nn.relu)
    layer_output = tf.layers.dense(layer_3, n_output_layer)

    return layer_output


# Train with 100 samples per step
batch_size = 100

X = tf.placeholder('float', [None, 28 * 28], name="input_image")
# [None, 28*28] is the shape (height and width) of the input matrix; specifying it means
# TensorFlow raises an error if the fed data does not match, but it can be left unspecified.
Y = tf.placeholder('float')


# Train the neural network on the data
def train(X, Y):
    predict = nn(X)
    cost_func = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=predict))
    optimizer = tf.train.AdamOptimizer().minimize(cost_func)  # learning rate defaults to 0.001

    epochs = 13
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            epoch_loss = 0
            for i in range(int(mnist.train.num_examples / batch_size)):
                image, label = mnist.train.next_batch(batch_size)
                _, c = session.run([optimizer, cost_func], feed_dict={X: image, Y: label})
                epoch_loss += c
            print(epoch, ' : ', epoch_loss)

        # Name the prediction op so it can be retrieved after the graph is restored
        argmax_predict = tf.argmax(predict, 1, name="argmax_predict")
        # correct = tf.equal(tf.argmax(predict, 1), tf.argmax(Y, 1))
        # accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
        # print('accuracy: ', accuracy.eval({X: mnist.test.images, Y: mnist.test.labels}))
        # print('accuracy: ', session.run(accuracy, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))

        save_path = saver.save(session, "my_net/save_net.ckpt")


train(X, Y)

Saving the network is as simple as creating saver = tf.train.Saver() in TensorFlow. Note that the placeholder X is given a name attribute when it is defined, and the argmax_predict op is given a name attribute in the same way (tf.argmax(predict, 1) is used here so the output is easier to read).
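As a side note, a tensor created with name="foo" can later be fetched from the default graph as "foo:0" (output 0 of the op named "foo"); that is why both the placeholder and the argmax op are given names here. A minimal sketch of the convention (the name foo is only an illustration, not part of the scripts in this post):

import tensorflow as tf

# A placeholder named "foo" creates an op "foo" whose first output tensor is "foo:0"
x = tf.placeholder('float', [None, 3], name="foo")
same_x = tf.get_default_graph().get_tensor_by_name("foo:0")
print(x is same_x)  # True: both variables refer to the same tensor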

Calling

save_path = saver.save(session, "my_net/save_net.ckpt")

saves the network parameters; once it completes, the parameter files are generated in the my_net directory.
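For reference, with the default Saver (V2 checkpoint format) the my_net directory typically ends up containing files along the lines of the comment below; the exact shard names can vary, so this is an illustration rather than a guaranteed listing:

import os

# checkpoint                         - text file recording the latest checkpoint path
# save_net.ckpt.meta                 - the graph definition (what import_meta_graph reads)
# save_net.ckpt.index                - index mapping variable names to the data shards
# save_net.ckpt.data-00000-of-00001  - the actual variable values
print(os.listdir("my_net"))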

 

Using the saved neural network

import tensorflow as tf
import numpy as np

# TensorFlow ships with a loader for the MNIST dataset
from tensorflow.examples.tutorials.mnist import input_data

# Download the MNIST dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


sess = tf.Session()
# First load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_net/save_net.ckpt.meta')
saver.restore(sess, tf.train.latest_checkpoint('./my_net'))

graph = tf.get_default_graph()
input_image = graph.get_tensor_by_name("input_image:0")

# Now access the op that you want to run
argmax_predict = graph.get_tensor_by_name("argmax_predict:0")

for i in range(10):
    argmax_Y = tf.argmax(mnist.test.labels[i].reshape(1, 10), 1)
    print('predicted : actual', sess.run([argmax_predict, argmax_Y],
                                         feed_dict={input_image: mnist.test.images[i].reshape(1, 28 * 28)}))

After retrieving the argmax_predict op and the input_image placeholder, feed the image to be recognized through sess.run(argmax_predict, ...) to obtain the prediction. The example runs the first ten images of the test set and prints:

predicted : actual [array([7]), array([7])]
predicted : actual [array([2]), array([2])]
predicted : actual [array([1]), array([1])]
predicted : actual [array([0]), array([0])]
predicted : actual [array([4]), array([4])]
predicted : actual [array([1]), array([1])]
predicted : actual [array([4]), array([4])]
predicted : actual [array([5]), array([9])]
predicted : actual [array([6]), array([5])]
predicted : actual [array([9]), array([9])]
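Beyond spot-checking ten images, the restored op can also be evaluated over the whole test set. A minimal sketch (not part of the original scripts), reusing sess, input_image, argmax_predict and mnist from the restore script above:

import numpy as np

# Run the restored prediction op on all test images at once
pred = sess.run(argmax_predict, feed_dict={input_image: mnist.test.images})
# Convert the one-hot test labels back to digit classes
actual = np.argmax(mnist.test.labels, axis=1)
print('test accuracy:', np.mean(pred == actual))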