Simple transfer learning with an Inception v3 architecture model (translated from ../tensorflow/tensorflow/examples/image_retraining/retrain.py)

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple transfer learning with an Inception v3 architecture model.

This example shows how to take an Inception v3 architecture model trained on
ImageNet images, and train a new top layer that can recognize other classes of
images.

The top layer receives as input a 2048-dimensional vector for each image. We
train a softmax layer on top of this representation. Assuming the softmax layer
contains N labels, this corresponds to learning N + 2048*N model parameters,
which are the learned biases and weights.

Here's an example, which assumes you have a folder containing class-named
subfolders, each full of images for each label. The example folder
flower_photos should have a structure like this:

~/flower_photos/daisy/photo1.jpg
~/flower_photos/daisy/photo2.jpg
...
~/flower_photos/rose/anotherphoto77.jpg
...
~/flower_photos/sunflower/somepicture.jpg

The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this:

bazel build third_party/tensorflow/examples/image_retraining:retrain && \
bazel-bin/third_party/tensorflow/examples/image_retraining/retrain \
--image_dir ~/flower_photos

You can replace the image_dir argument with any folder containing subfolders of
images. The label for each image is taken from the name of the subfolder it's
in.

This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.
"""
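# For example (illustrative numbers, not taken from the text above): with N = 5
# flower classes, the new softmax layer learns 2048 * 5 = 10,240 weights plus
# 5 biases, i.e. 10,245 trainable parameters, while the Inception v3 weights
# below it stay frozen.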
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import glob
import hashlib
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from tensorflow.python.client import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile


FLAGS = tf.app.flags.FLAGS

# Input and output file flags.
tf.app.flags.DEFINE_string('image_dir', '',
                           """Path to folders of labeled images.""")
tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb',
                           """Where to save the trained graph.""")
tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt',
                           """Where to save the trained graph's labels.""")

# Details of the training configuration.
tf.app.flags.DEFINE_integer('how_many_training_steps', 4000,
                            """How many training steps to run before ending.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01,
                          """How large a learning rate to use when training.""")
tf.app.flags.DEFINE_integer(
    'testing_percentage', 10,
    """What percentage of images to use as a test set.""")
tf.app.flags.DEFINE_integer(
    'validation_percentage', 10,
    """What percentage of images to use as a validation set.""")
tf.app.flags.DEFINE_integer('eval_step_interval', 10,
                            """How often to evaluate the training results.""")
tf.app.flags.DEFINE_integer('train_batch_size', 100,
                            """How many images to train on at a time.""")
tf.app.flags.DEFINE_integer('test_batch_size', 500,
                            """How many images to test on at a time. This"""
                            """ test set is only used infrequently to verify"""
                            """ the overall accuracy of the model.""")
tf.app.flags.DEFINE_integer(
    'validation_batch_size', 100,
    """How many images to use in an evaluation batch. This validation set is"""
    """ used much more often than the test set, and is an early indicator of"""
    """ how accurate the model is during training.""")

# File-system cache locations.
tf.app.flags.DEFINE_string('model_dir', '/tmp/imagenet',
                           """Path to classify_image_graph_def.pb, """
                           """imagenet_synset_to_human_label_map.txt, and """
                           """imagenet_2012_challenge_label_map_proto.pbtxt.""")
tf.app.flags.DEFINE_string(
    'bottleneck_dir', '/tmp/bottleneck',
    """Path to cache bottleneck layer values as files.""")
tf.app.flags.DEFINE_string('final_tensor_name', 'final_result',
                           """The name of the output classification layer in"""
                           """ the retrained graph.""")

# Controls the distortions used during training.
tf.app.flags.DEFINE_boolean(
    'flip_left_right', False,
    """Whether to randomly flip half of the training images horizontally.""")
tf.app.flags.DEFINE_integer(
    'random_crop', 0,
    """A percentage determining how much of a margin to randomly crop off the"""
    """ training images.""")
tf.app.flags.DEFINE_integer(
    'random_scale', 0,
    """A percentage determining how much to randomly scale up the size of the"""
    """ training images by.""")
tf.app.flags.DEFINE_integer(
    'random_brightness', 0,
    """A percentage determining how much to randomly multiply the training"""
    """ image input pixels up or down by.""")

# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and
# their sizes. If you want to adapt this script to work with another model,
# you will need to update these to reflect the values in the network you're
# using.
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
BOTTLENECK_TENSOR_SIZE = 2048
MODEL_INPUT_WIDTH = 299
MODEL_INPUT_HEIGHT = 299
MODEL_INPUT_DEPTH = 3
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the subfolders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    A dictionary containing an entry for each label subfolder, with images
    split into training, testing, and validation sets within each label.
  """
  if not gfile.Exists(image_dir):
    print("Image directory '" + image_dir + "' not found.")
    return None
  result = {}
  sub_dirs = [x[0] for x in os.walk(image_dir)]
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    print("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(glob.glob(file_glob))
    if not file_list:
      print('No files found')
      continue
    if len(file_list) < 20:
      print('WARNING: Folder has less than 20 images, which may cause issues.')
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in: the data set creator has a way
      # of grouping photos that are close variations of each other. For example
      # this is used in the plant disease data set to group multiple pictures
      # of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file
      # should go into the training, testing, or validation sets, and we want
      # to keep existing files in the same set even if more files are
      # subsequently added. To do that, we need a stable way of deciding based
      # on just the file name itself, so we hash it and use the hash to
      # generate a probability value that we use to assign the file to a set.
      hash_name_hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
      percentage_hash = (int(hash_name_hashed, 16) % (65536)) * (100 / 65535.0)
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result

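# Illustrative sketch (hypothetical label and file names) of the structure that
# create_image_lists() returns, to make the dictionary lookups in the helpers
# below easier to follow:
#
#   {'daisy': {'dir': 'daisy',
#              'training': ['photo1.jpg', 'photo2.jpg'],
#              'testing': ['photo37.jpg'],
#              'validation': ['photo99.jpg']}}
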
def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
    available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
    images.
    category: Name string of the set to pull images from - training, testing,
    or validation.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  if label_name not in image_lists:
    tf.logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    tf.logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    tf.logging.fatal('Category has no images - %s.', category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path

def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category):
  """Returns a path to a bottleneck file for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
    available number of images for the label, so it can be arbitrarily large.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    category: Name string of the set to pull images from - training, testing,
    or validation.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '.txt'

def create_inception_graph():
  """Creates a graph from the saved GraphDef file and returns a Graph object.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: Numpy array of image data.
    image_data_tensor: Input data layer in the graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  bottleneck_values = sess.run(
      bottleneck_tensor,
      {image_data_tensor: image_data})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def maybe_download_and_extract():
  """Downloads and extracts the model tar file.

  If the pretrained model we're using doesn't already exist, this function
  downloads it from the TensorFlow.org website and unpacks it into a directory.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL,
                                             filepath,
                                             reporthook=_progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk.

  Args:
    dir_name: Path string to the folder we want to create.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             bottleneck_tensor):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on disk, return that,
  otherwise calculate the data and save it to disk for future use.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
    available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
    images.
    category: Name string of the set to pull images from - training, testing,
    or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    bottleneck_tensor: The output tensor for the bottleneck values.

  Returns:
    Numpy array of values produced by the bottleneck layer for the image.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category)
  if not os.path.exists(bottleneck_path):
    print('Creating bottleneck at ' + bottleneck_path)
    image_path = get_image_path(image_lists, label_name, index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    image_data = gfile.FastGFile(image_path, 'rb').read()
    bottleneck_values = run_bottleneck_on_image(sess, image_data,
                                                jpeg_data_tensor,
                                                bottleneck_tensor)
    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
    with open(bottleneck_path, 'w') as bottleneck_file:
      bottleneck_file.write(bottleneck_string)

  with open(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


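# Note on the cache layout (the example path is hypothetical): each image gets
# one text file of 2048 comma-separated floats, e.g.
# /tmp/bottleneck/daisy/photo1.jpg.txt, so later training steps can reuse the
# values without re-running the Inception graph on the image.
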
def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times (if there are no
  distortions applied during training) it can speed things up a lot if we
  calculate the bottleneck layer values once for each image during
  preprocessing, and then just read those cached values repeatedly during
  training. Here we go through all the images we've found, calculate those
  values, and save them off.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    image_dir: Root folder string of the subfolders containing the training
    images.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    bottleneck_tensor: The penultimate output layer of the graph.

  Returns:
    Nothing.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(sess, image_lists, label_name, index,
                                 image_dir, category, bottleneck_dir,
                                 jpeg_data_tensor, bottleneck_tensor)
        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          print(str(how_many_bottlenecks) + ' bottleneck files created.')


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  bottleneck_tensor):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk. It picks a random set of images from
  the specified category.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: The number of bottleneck values to return.
    category: Name string of the set to pull images from - training, testing,
    or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    image_dir: Root folder string of the subfolders containing the training
    images.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(65536)
    bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                          image_index, image_dir, category,
                                          bottleneck_dir, jpeg_data_tensor,
                                          bottleneck_tensor)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck)
    ground_truths.append(ground_truth)
  return bottlenecks, ground_truths


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have to
  recalculate the full model for every image, and so we can't use cached
  bottleneck values. Instead we find random images for the requested category,
  run them through the distortion graph, and then through the full graph to get
  the bottleneck results for each.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: The number of bottleneck values to return.
    category: Name string of the set to pull images from - training, testing,
    or validation.
    image_dir: Root folder string of the subfolders containing the training
    images.
    input_jpeg_tensor: The input layer we feed the image data to.
    distorted_image: The output node of the distortion graph.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(65536)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    jpeg_data = gfile.FastGFile(image_path, 'rb').read()
    # Note that we materialize the distorted_image_data as a numpy array before
    # running inference on the image. This involves two memory copies and could
    # be optimized in other implementations.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck = run_bottleneck_on_image(sess, distorted_image_data,
                                         resized_input_tensor,
                                         bottleneck_tensor)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck)
    ground_truths.append(ground_truth)
  return bottlenecks, ground_truths


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Determines whether any distortions are enabled, from the input flags.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
    crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    Boolean value indicating whether any distortions should be applied.
  """
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect the
  kind of variations we expect in the real world, and so can help train the
  model to cope with natural data more effectively. Here we take the supplied
  parameters and construct a network of operations to apply them to an image.

  Cropping
  ~~~~~~~~

  Cropping is done by placing a bounding box at a random position in the full
  image. The cropping parameter controls the size of that box relative to the
  input image. If it's zero, then the box is the same size as the input and no
  cropping is performed. If the value is 50%, then the crop box will be half the
  width and height of the input. In a diagram it looks like this:

  <       width         >
  +---------------------+
  |                     |
  |   width - crop%     |
  |    <      >         |
  |    +------+         |
  |    |      |         |
  |    |      |         |
  |    |      |         |
  |    +------+         |
  |                     |
  |                     |
  +---------------------+

  Scaling
  ~~~~~~~

  Scaling is a lot like cropping, except that the bounding box is always
  centered and its size varies randomly within the given range. For example if
  the scale percentage is zero, then the bounding box is the same size as the
  input and no scaling is applied. If it's 50%, then the bounding box will be in
  a random range between half the width and height and full size.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
    crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    The jpeg input layer and the distorted result tensor.
  """
  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data)
  decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.mul(margin_scale_value, resize_scale_value)
  precrop_width = tf.mul(scale_value, MODEL_INPUT_WIDTH)
  precrop_height = tf.mul(scale_value, MODEL_INPUT_HEIGHT)
  precrop_shape = tf.pack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
                                  MODEL_INPUT_DEPTH])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(tensor_shape.scalar(),
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.mul(flipped_image, brightness_value)
  distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
  return jpeg_data, distort_result


def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
  """Adds a new softmax and fully-connected layer for training.

  We need to retrain the top layer to identify our new classes, so this
  function adds the right operations to the graph, along with some variables to
  hold the weights, and then sets up all the gradients for the backward pass.

  The set up for the softmax and fully-connected layers is based on:
  https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
    recognize.
    final_tensor_name: Name string for the new final node that produces results.
    bottleneck_tensor: The output of the main CNN graph.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  bottleneck_input = tf.placeholder_with_default(
      bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
      name='BottleneckInputPlaceholder')
  layer_weights = tf.Variable(
      tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001),
      name='final_weights')
  layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
  logits = tf.matmul(bottleneck_input, layer_weights,
                     name='final_matmul') + layer_biases
  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
  ground_truth_input = tf.placeholder(tf.float32,
                                      [None, class_count],
                                      name='GroundTruthInput')
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits, ground_truth_input)
  cross_entropy_mean = tf.reduce_mean(cross_entropy)
  train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
      cross_entropy_mean)
  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data into.

  Returns:
    The step to run for evaluating prediction accuracy.
  """
  correct_prediction = tf.equal(
      tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1))
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
  return evaluation_step


def main(_):
  # Set up the pre-trained graph.
  maybe_download_and_extract()
  graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
      create_inception_graph())

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    print('No valid folders of images found at ' + FLAGS.image_dir)
    return -1
  if class_count == 1:
    print('Only one valid folder of images found at ' + FLAGS.image_dir +
          ' - multiple classes are needed for classification.')
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)
  sess = tf.Session()

  if do_distort_images:
    # We will be applying distortions, so set up the operations we'll need.
    distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions(
        FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
        FLAGS.random_brightness)
  else:
    # We'll make sure we've calculated the 'bottleneck' image summaries and
    # cached them on disk.
    cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor)

  # Add the new layer that we'll be training.
  (train_step, cross_entropy, bottleneck_input, ground_truth_input,
   final_tensor) = add_final_training_ops(len(image_lists.keys()),
                                          FLAGS.final_tensor_name,
                                          bottleneck_tensor)

  # Set up all our weights to their initial default values.
  init = tf.initialize_all_variables()
  sess.run(init)

  # Create the operations we need to evaluate the accuracy of our new layer.
  evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)

  # Run the training for as many cycles as requested on the command line.
  for i in range(FLAGS.how_many_training_steps):
    # Get a batch of input bottleneck values, either calculated fresh every
    # time with distortions applied, or from the cache stored on disk.
    if do_distort_images:
      train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.image_dir, distorted_jpeg_data_tensor,
          distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
    else:
      train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
          bottleneck_tensor)
    # Feed the bottlenecks and ground truth into the graph, and run a training
    # step.
    sess.run(train_step,
             feed_dict={bottleneck_input: train_bottlenecks,
                        ground_truth_input: train_ground_truth})
    # Every so often, print out how well the graph is training.
    is_last_step = (i + 1 == FLAGS.how_many_training_steps)
    if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
      train_accuracy, cross_entropy_value = sess.run(
          [evaluation_step, cross_entropy],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
                                                      train_accuracy * 100))
      print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
                                                 cross_entropy_value))
      validation_bottlenecks, validation_ground_truth = (
          get_random_cached_bottlenecks(
              sess, image_lists, FLAGS.validation_batch_size, 'validation',
              FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
              bottleneck_tensor))
      validation_accuracy = sess.run(
          evaluation_step,
          feed_dict={bottleneck_input: validation_bottlenecks,
                     ground_truth_input: validation_ground_truth})
      print('%s: Step %d: Validation accuracy = %.1f%%' %
            (datetime.now(), i, validation_accuracy * 100))

  # We've completed all our training, so run a final test evaluation on some
  # new images we haven't used before.
  test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks(
      sess, image_lists, FLAGS.test_batch_size, 'testing',
      FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
      bottleneck_tensor)
  test_accuracy = sess.run(
      evaluation_step,
      feed_dict={bottleneck_input: test_bottlenecks,
                 ground_truth_input: test_ground_truth})
  print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

  # Write out the trained graph and labels with the weights stored as constants.
  output_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
  with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
    f.write('\n'.join(image_lists.keys()) + '\n')


if __name__ == '__main__':
  tf.app.run()

My English is limited, haha. Original source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

    Due to time constraints, I cut things off after finishing class 3. There are four groups of 28*28 images, roughly 6,000 per group, split train:validation:test = 4:1:1. The training photos are under TRAINIMAGE and the photos used for recognition under VALIDATEIMAGE; training took about two and a half days (56 hours). The final recognition results are below (the sketch right after this paragraph shows how such a split maps onto the script's percentage flags).
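    A minimal sketch (my assumption, not the exact command used for the log below): a 4:1:1 split corresponds to roughly 17% validation and 17% testing, which retrain.py expresses through its percentage flags:

bazel-bin/tensorflow/examples/image_retraining/retrain \
  --image_dir ~/FISHMNIST2/TRAINIMAGE/ \
  --validation_percentage 17 --testing_percentage 17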
    1. Training on classes 0, 1, 2, 3.
user01@user01-forfish:~/Documents/tensorflow$ bazel-bin/tensorflow/examples/image_retraining/retrain --image_dir ~/FISHMNIST2/TRAINIMAGE/
Looking for images in '1'
Looking for images in '0'
Looking for images in '3'
Looking for images in '2'
100 bottleneck files created.
200 bottleneck files created.
300 bottleneck files created.
400 bottleneck files created.
............
24300 bottleneck files created.
24400 bottleneck files created.
24500 bottleneck files created.
24600 bottleneck files created.
24700 bottleneck files created.
2016-04-07 16:03:35.848447: Step 0: Train accuracy = 56.0%
2016-04-07 16:03:35.848516: Step 0: Cross entropy = 1.286888
2016-04-07 16:03:36.059755: Step 0: Validation accuracy = 35.0%
2016-04-07 16:03:37.677323: Step 10: Train accuracy = 92.0%
2016-04-07 16:03:37.677392: Step 10: Cross entropy = 0.810003
2016-04-07 16:03:37.817034: Step 10: Validation accuracy = 84.0%
2016-04-07 16:03:39.249091: Step 20: Train accuracy = 93.0%
2016-04-07 16:03:39.249168: Step 20: Cross entropy = 0.611544
............
2016-04-07 16:07:54.362384: Step 3980: Validation accuracy = 99.0%
2016-04-07 16:07:54.905122: Step 3990: Train accuracy = 100.0%
2016-04-07 16:07:54.905188: Step 3990: Cross entropy = 0.028465
2016-04-07 16:07:54.954875: Step 3990: Validation accuracy = 99.0%
2016-04-07 16:07:55.431090: Step 3999: Train accuracy = 100.0%
2016-04-07 16:07:55.431152: Step 3999: Cross entropy = 0.044311
2016-04-07 16:07:55.480975: Step 3999: Validation accuracy = 97.0%
Final test accuracy = 99.0%
Converted 2 variables to const ops.

    2. Using the training result above.
        (1) Recognizing a handwritten 0:
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && \
> bazel-bin/tensorflow/examples/label_image/label_image \
> --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
> --output_layer=final_result \
> --image=$HOME/FISHMNIST2/VALIDATEIMAGE/0/MUSIC-1383.jpg
..............
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
  bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 42.557s, Critical Path: 23.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 0 (1): 0.999816
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.000103101
I tensorflow/examples/label_image/main.cc:207] 3 (2): 7.5385e-05
I tensorflow/examples/label_image/main.cc:207] 1 (0): 5.35907e-06

        (2) Recognizing a handwritten 1:
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/1/JAVA-10000.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
  bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.196s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 1 (0): 0.999564
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.000373375
I tensorflow/examples/label_image/main.cc:207] 3 (2): 4.03145e-05
I tensorflow/examples/label_image/main.cc:207] 0 (1): 2.2671e-05


        (3) Recognizing a handwritten 2:
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/2/CHUAN-10006.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
  bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.146s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
............
GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.99146
I tensorflow/examples/label_image/main.cc:207] 3 (2): 0.00433606
I tensorflow/examples/label_image/main.cc:207] 1 (0): 0.00381229
I tensorflow/examples/label_image/main.cc:207] 0 (1): 0.000391908


        (4) Recognizing a handwritten 3:
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/3/YAHOO-1089.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
  bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.121s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 3 (2): 0.99796
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.00196124
I tensorflow/examples/label_image/main.cc:207] 0 (1): 5.19803e-05
I tensorflow/examples/label_image/main.cc:207] 1 (0): 2.67401e-05

    The recognition accuracy is quite high.
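    For completeness, here is a minimal Python sketch (my own addition, not from the original post) of how the retrained graph could be loaded and queried without the C++ label_image binary. It assumes the TensorFlow 0.x-era API used by retrain.py above, the default output paths, and a hypothetical test image path; the 'DecodeJpeg/contents:0' and 'final_result:0' tensor names come from the script itself.

import tensorflow as tf

# Load the retrained graph that retrain.py wrote to --output_graph.
with open('/tmp/output_graph.pb', 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')

# Labels were written one per line, in the same order as the softmax outputs.
with open('/tmp/output_labels.txt') as f:
  labels = [line.strip() for line in f]

with tf.Session() as sess:
  softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
  # Feed raw JPEG bytes into the same decode node the retraining used.
  jpeg_data = open('/path/to/a_test_image.jpg', 'rb').read()  # hypothetical path
  predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': jpeg_data})
  for label, score in sorted(zip(labels, predictions[0]), key=lambda x: -x[1]):
    print(label, score)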

posted @ 2016-04-06 09:54  cestlavie