tensorflow中关于vgg16的项目
转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html
tflearn自带的VGG16训练示例:https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py 尚未测试成功。
下面的项目是使用别人已经训练好的模型进行预测,测试效果非常好。
GitHub:https://github.com/ry/tensorflow-vgg16 此项目已经测试成功,效果非常好。
如果在Ubuntu的终端(terminal)中运行时出现问题,可以参照下面的代码解决(主要是skimage读取图片的问题)。
#coding:utf-8
import skimage
import skimage.io
import skimage.transform
# NOTE(review): 'a' is never used anywhere below; this looks like a leftover
# check that skimage can decode a JPEG. It reads cat.jpg from the current
# working directory (the prediction loop reads its images from pic/ instead),
# so it will likely fail if that file is absent -- consider removing it.
a=skimage.io.imread('cat.jpg')
import PIL
import numpy as np
import tensorflow as tf
# Class labels, one per line; a label's line index is the class id used to
# index the model's "prob" output. The with-block closes the file handle,
# which the previous open(...).readlines() one-liner leaked.
with open('/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt') as synset_file:
    synset = [l.strip() for l in synset_file]
def load_image(path):
    """Load an image file and prepare it as a single VGG16 input.

    The pixels are scaled by 1/255, the largest centered square is cropped
    out, and the crop is resized to 224x224. Grayscale images are replicated
    into three identical channels and an alpha channel (LA / RGBA) is
    dropped, so the result can always be reshaped to (1, 224, 224, 3) for
    the 3-channel "images" placeholder.

    :param path: path of the image file to read.
    :returns: float array of shape (224, 224, 3).
    :raises AssertionError: if the scaled pixels fall outside [0, 1]
        (e.g. for 16-bit images).
    """
    img = skimage.io.imread(path)
    # assumes 8-bit pixel data (e.g. JPEG); higher bit depths trip the assert
    img = img / 255.0
    assert (0 <= img).all() and (img <= 1.0).all()

    # Normalize the channel layout to exactly three channels.
    if img.ndim == 3 and img.shape[2] in (2, 4):
        # LA or RGBA: drop the trailing alpha channel
        img = img[:, :, :-1]
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 2:
        # grayscale: repeat the single channel three times
        img = np.dstack([img, img, img])

    # Crop the largest square centered in the image.
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    resized_img = skimage.transform.resize(crop_img, (224, 224))
    return resized_img
def print_prob(prob, labels=None):
    """Print the shape of *prob* and return the label of the top-1 class.

    :param prob: per-class scores for a single image -- assumed 1-D, e.g.
        one row of the model's "prob" output.
    :param labels: sequence of class labels indexed by class id; defaults
        to the module-level ``synset`` list.
    :returns: the label whose score in *prob* is highest.
    """
    if labels is None:
        labels = synset
    # single-argument print behaves the same under Python 2 and 3
    print("prob shape %s" % (prob.shape,))
    # Only the top-1 label is returned, so argmax suffices; the former full
    # argsort and unused top-5 list (which crashed on < 5 classes) are gone.
    return labels[int(np.argmax(prob))]
# Read the serialized, pre-trained VGG16 GraphDef from disk.
print(u'加载模型文件')  # "loading model file"
with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", 'rb') as f:
    fileContent = f.read()

print(u'创建图')  # "building graph"
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
# Batch of 224x224, 3-channel float images. input_map makes the imported
# graph read from this placeholder in place of its own "images" input.
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map=dict(images=images))
print("graph loaded from disk")

# The default graph now holds the imported ops (looked up later under the
# "import/" name prefix, e.g. "import/prob:0").
graph = tf.get_default_graph()
print(u'加载图片')  # "loading images"
print(u'进入sess执行')  # "entering the session"
result = []

# Build graph ops once, before the loop: the original created a fresh
# initializer op for every image, which kept adding nodes to the graph.
# NOTE(review): the imported GraphDef presumably registers no tf.Variables,
# so this initializer is likely a no-op; kept to preserve the original flow.
init = tf.initialize_all_variables()
prob_tensor = graph.get_tensor_by_name("import/prob:0")

# The with-block closes the session even if loading an image raises.
with tf.Session() as sess:
    sess.run(init)
    print("variables initialized")
    for i in ['cat.jpg', 'airplane.jpg', 'zebra.jpg', 'pig.jpg', '12.jpg', '23.jpg']:
        img = load_image('pic/' + i)
        # the placeholder takes a batch, so add a leading batch axis of size 1
        batch = img.reshape((1, 224, 224, 3))
        assert batch.shape == (1, 224, 224, 3)
        feed_dict = {images: batch}
        print(u'开始执行')  # "running inference"
        prob = sess.run(prob_tensor, feed_dict=feed_dict)
        print(u'输出结果')  # "output result"
        result.append(print_prob(prob[0]))

# top-1 label for each image, in the order listed above
print(result)
# NOTE(review): dead code. This triple-quoted string is evaluated and
# discarded when the script runs; it holds an earlier single-image version of
# the prediction loop above. It reads `cat`, which is not assigned anywhere in
# the active code, so it would fail with a NameError if re-enabled as-is.
# Consider deleting it.
'''
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
print "variables initialized"
batch = cat.reshape((1, 224, 224, 3))
assert batch.shape == (1, 224, 224, 3)
feed_dict = { images: batch }
print u'开始执行'
prob_tensor = graph.get_tensor_by_name("import/prob:0")
prob = sess.run(prob_tensor, feed_dict=feed_dict)
print u'输出结果'
print_prob(prob[0])
'''