TensorFlow 中关于 VGG16 的项目

转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html

tflearn 中训练 VGG16 的示例:https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py ,该示例尚未测试成功。

下面的项目是使用别人已经训练好的模型进行预测,测试效果非常好。

GitHub:https://github.com/ry/tensorflow-vgg16 ,此项目已测试成功,效果非常好。

如果在 Ubuntu 终端中运行时出现问题(例如 skimage 无法读取图片),可以参照下面修改后的代码解决。注意:下面的代码使用 Python 2 语法(如 print 语句),需要在 Python 2 环境下运行。

#coding:utf-8


import skimage
import skimage.io
import skimage.transform
# NOTE(review): `a` is never used afterwards. Reading an image with skimage
# before PIL and TensorFlow are imported looks like a deliberate workaround for
# the skimage image-reading problem mentioned in the text above -- confirm
# before removing. The path is relative, so cat.jpg must exist in the current
# working directory (the images classified below are loaded from pic/).
a=skimage.io.imread('cat.jpg')
import PIL
import numpy as np
import tensorflow as tf
# Class labels, one per line: entry i names output index i of the model's
# "prob" tensor (print_prob indexes this list with the arg-max of prob).
# The `with` block closes the file instead of leaking the handle.
with open('/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt') as synset_file:
    synset = [line.strip() for line in synset_file]

def load_image(path, size=224):
    """Load an image file and return a square RGB float array for VGG16.

    The pixels are scaled from 8-bit [0, 255] to [0, 1]. The image is then
    center-cropped to a square along its shorter edge and resized to
    ``size`` x ``size``.

    Grayscale images are expanded to three identical channels, and an alpha
    channel, if present, is dropped. This lets callers always reshape the
    result to (1, size, size, 3).

    Args:
        path: Image file path, as accepted by ``skimage.io.imread``.
        size: Side length of the returned image (default 224, the VGG16
            input size used by this script).

    Returns:
        numpy float array of shape (size, size, 3).

    Raises:
        ValueError: if the pixel values are not in [0, 1] after dividing by
            255, i.e. the image is not 8-bit.
    """
    img = skimage.io.imread(path)

    # Normalise the channel layout: the network input has exactly 3 channels.
    if img.ndim == 2:
        # Grayscale: replicate the single channel into R, G and B.
        img = np.dstack((img, img, img))
    elif img.ndim == 3 and img.shape[2] == 4:
        # RGBA: drop the alpha channel.
        img = img[:, :, :3]

    img = img / 255.0
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not ((0 <= img).all() and (img <= 1.0).all()):
        raise ValueError("expected 8-bit pixel values in [0, 255]: %s" % path)

    # Crop the largest centered square.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    # resize keeps the trailing channel axis, so the result is (size, size, 3).
    return skimage.transform.resize(crop_img, (size, size))
  
# returns the top1 string
def print_prob(prob):
  #print prob
  print "prob shape", prob.shape
  pred = np.argsort(prob)[::-1]
  # Get top1 label
  top1 = synset[pred[0]]
  #print "Top1: ", top1
  # Get top5 label
  top5 = [synset[pred[i]] for i in range(5)]
  #print "Top5: ", top5
  return top1

print u'加载模型文件'
with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", mode='rb') as f:
  fileContent = f.read()
  
print u'创建图'
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print "graph loaded from disk"

graph = tf.get_default_graph()
print u'加载图片'
#img=np.array(PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg"))
#cat = load_image(path)
print u'进入sess执行'

sess=tf.Session()
result=[]
for i in ['cat.jpg','airplane.jpg','zebra.jpg','pig.jpg','12.jpg','23.jpg']:
  img=load_image('pic/'+i)
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = img.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u'开始执行'
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)
  print u'输出结果'
  #print_prob(prob[0])
  result.append(print_prob(prob[0]))


print result
sess.close()


# NOTE: The triple-quoted string below is never executed. It keeps an earlier
# single-image version of the inference code for reference only. It uses
# `cat`, which this script never defines, so it would not run as-is.
'''
with tf.Session() as sess:
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = cat.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u'开始执行'
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)

print u'输出结果'
print_prob(prob[0])

'''

  

posted on 2016-06-30 16:49  徐长卿学数据分析  阅读(7842)  评论(0)  编辑  收藏  举报