Tensorflow导入训练模型进行识别(附代码)
最近在做李宏毅的深度学习的作业,导入模型的时候,发现,我在导入模型进行预测时,需要重新手动构建网络进行检测,这样显得十分不“智能”。之前在比赛中一直是使用这种方法,但是由于当初比较忙,并没有深究这个问题。现在,学习了一下,发现使用Tensorflow 可以用两种方法进行预测。
首先,我们来讲一下,如何将如何加载模型:
在TensorFlow中,加载模型的方法:
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, 'path_of_model')
接下来,我们看一下加载模型并预测的方法
一、手动重写网络模型,进行检测
顾名思义,就是在加载自己训练的模型之前,重新手写一遍自己用来训练的网络,将图像的特征提取处理,然后得到feature map后传入模型进行预测,然我们来看看代码吧:
data = tf.placeholder(tf.float32, [None, 28*28])
label = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
image = tf.reshape(data, [-1, 28, 28, 1])
# conv1
W_conv1 = Weights([5, 5, 1, 6])
B_conv1 = Bias([6])
conv1 = tf.nn.conv2d(image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
relu_1 = tf.nn.relu(conv1+B_conv1)
pooling1 = tf.nn.max_pool(relu_1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
#conv2
W_conv2 = Weights([5, 5, 6, 16])
B_conv2 = Bias([16])
conv2 = tf.nn.conv2d(pooling1, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
relu_2 = tf.nn.relu(conv2+B_conv2)
pooling2 = tf.nn.max_pool(relu_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
flatten = tf.reshape(pooling2, [-1, 7*7*16])
#FCL1
W_FCL1 = Weights([7*7*16, 84])
B_FCL1 = Weights([84])
FCL1 = tf.matmul(flatten, W_FCL1) + B_FCL1
#FCL2
W_FCL2 = Weights([84, 10])
B_FCL2 = Weights([10])
FCL2 = tf.matmul(FCL1, W_FCL2) + B_FCL2
prediction = tf.nn.softmax(FCL2)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, './model/mnist_cnn.ckpt')#加载模型
print(sess.run(res, feed_dict={data: pic, keep_prob:1.0}))
这样就可以输出预测的值啦。虽然说,手动重写网络结构比较简单直接(毕竟可以直接复制训练代码),但是这样未免显得不太智能。事实上,TensorFlow在训练时,已经给我们,保存了所有变量和网络结构。我们先来看一下训练时保存的文件:
各文件功能:
1、meta文件保存的是图结构,也就是网络模型,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。
2、ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:
mnist_cnn.ckpt.data-00000-of-00001
mnist_cnn.index
3、checkpoint文件
我们还可以看,checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model
好了,下面我们来说第二个方法
二、使用meta文件导入网络结构
直接看一下代码:
with tf.Session() as sess:
#载入meta文件
saver = tf.train.import_meta_graph('./model/mnist_cnn.ckpt.meta')
#载入最近的一次保存的模型文件
saver.restore(sess, tf.train.latest_checkpoint("./model/"))
#建立图
graph = tf.get_default_graph()
#初始化所有的变量
sess.run(tf.global_variables_initializer())
#获取网络中的变量
X = graph.get_tensor_by_name('data:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')
result = graph.get_tensor_by_name('prediction:0')
#这上面的‘data’、‘keep_prob’、‘prediction’都是我们在训练时定义的变量名称
print(sess.run(result, feed_dict={X: pic, keep_prob: 1.0}))
我们来讲一下最重要的两个函数:
#这里是获取变量名
tf.get_default_graph().get_tensor_by_name('tensor_name')
#这里是获取对应的operation
tf.get_default_graph().get_tensor_by_name('op_name')
其中,get_tensor_by_name('tensor_name')中的tensor_name的组成是:定义的变量名加上数字
也就是 data:num
num指的是第几个变量,这个是必须的。因为当只有变量名时,tensoflow会认为是operation
可以看一下去除后的报错信息:
ValueError: The name 'data' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".
翻译:值错误,‘data’是指向一个operation,而不是一个Tensor。Tensor名字必须有是<op_name>:<output_index>这种格式
那么,关键来了,众所周知,我们在使用tensorflow的时候是允许同时对两个变量使用同一个变量名的。
例如:
data = tf.placeholder(tf.float32, [None, 28*28], name='data')
label = tf.placeholder(tf.float32, [None, 10], name='data')
所以,output_index就变得十分重要,他表示是同名变量的第几个。
还有就是,当我们没有自己定义变量名时,会默认定义为placeholder
data = tf.placeholder(tf.float32, [None, 28*28])
label = tf.placeholder(tf.float32, [None, 10])
那么我们在使用时,就要用一下方式:
#获得data
tf.get_default_graph().get_tensor_by_name('placeholder:0')
#获得label
tf.get_default_graph().get_tensor_by_name('placeholder:1')
好了,以上就是我个人的一点点总结,希望对大家有帮助
完整代码可以到我的GitHub获取:https://github.com/Dylanin1999/DL_HW
代码实现了mnist集的训练和检测
要是觉得有帮助的朋友,麻烦给个star呗
posted on 2022-08-13 16:15 DylanYeung 阅读(180) 评论(0) 编辑 收藏 举报