Simple transfer learning with an Inception v3 architecture model (a translation of ../tensorflow/tensorflow/examples/image_retraining/retrain.py)
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple transfer learning with an Inception v3 architecture model.

This example shows how to take an Inception v3 architecture model trained on
ImageNet images, and train a new top layer that can recognize other classes
of images.

The top layer receives as input a 2048-dimensional vector for each image. We
train a softmax layer on top of this representation. Assuming the softmax
layer contains N labels, this corresponds to learning N + 2048*N model
parameters for the biases and weights.

Here's an example, which assumes you have a folder containing class-named
subfolders, each full of images for each label. The example folder
flower_photos should have a structure like this:

~/flower_photos/daisy/photo1.jpg
~/flower_photos/daisy/photo2.jpg
...
~/flower_photos/rose/anotherphoto77.jpg
...
~/flower_photos/sunflower/somepicture.jpg

The subfolder names are important, since they define what label is applied
to each image, but the filenames themselves don't matter. Once your images
are prepared, you can run the training with commands like these:

bazel build third_party/tensorflow/examples/image_retraining:retrain && \
bazel-bin/third_party/tensorflow/examples/image_retraining/retrain \
--image_dir ~/flower_photos

You can replace the image_dir argument with any folder containing subfolders
of images. The label for each image is taken from the name of the subfolder
it's in.

This produces a new model file that can be loaded and run by any TensorFlow
program, for example the label_image sample code.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import glob
import hashlib
import os.path
import random
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from tensorflow.python.client import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile


FLAGS = tf.app.flags.FLAGS

# Input and output file flags.
tf.app.flags.DEFINE_string('image_dir', '',
                           """Path to folders of labeled images.""")
tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb',
                           """Where to save the trained graph.""")
tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt',
                           """Where to save the trained graph's labels.""")

# Details of the training configuration.
tf.app.flags.DEFINE_integer('how_many_training_steps', 4000,
                            """How many training steps to run before ending.""")
tf.app.flags.DEFINE_float('learning_rate', 0.01,
                          """How large a learning rate to use when training.""")
tf.app.flags.DEFINE_integer(
    'testing_percentage', 10,
    """What percentage of images to use as a test set.""")
tf.app.flags.DEFINE_integer(
    'validation_percentage', 10,
    """What percentage of images to use as a validation set.""")
tf.app.flags.DEFINE_integer('eval_step_interval', 10,
                            """How often to evaluate the training results.""")
tf.app.flags.DEFINE_integer('train_batch_size', 100,
                            """How many images to train on at a time.""")
tf.app.flags.DEFINE_integer('test_batch_size', 500,
                            """How many images to test on at a time. This"""
                            """ test set is only used infrequently to verify"""
                            """ the overall accuracy of the model.""")
tf.app.flags.DEFINE_integer(
    'validation_batch_size', 100,
    """How many images to use in an evaluation batch. This validation set is"""
    """ used much more often than the test set, and is an early indicator of"""
    """ how accurate the model is during training.""")

# File-system cache locations.
tf.app.flags.DEFINE_string('model_dir', '/tmp/imagenet',
                           """Path to classify_image_graph_def.pb, """
                           """imagenet_synset_to_human_label_map.txt, and """
                           """imagenet_2012_challenge_label_map_proto.pbtxt.""")
tf.app.flags.DEFINE_string(
    'bottleneck_dir', '/tmp/bottleneck',
    """Path to cache bottleneck layer values as files.""")
tf.app.flags.DEFINE_string('final_tensor_name', 'final_result',
                           """The name of the output classification layer"""
                           """ in the retrained graph.""")

# Controls the distortions applied during training.
tf.app.flags.DEFINE_boolean(
    'flip_left_right', False,
    """Whether to randomly flip half of the training images horizontally.""")
tf.app.flags.DEFINE_integer(
    'random_crop', 0,
    """A percentage determining how much of a margin to randomly crop off the"""
    """ training images.""")
tf.app.flags.DEFINE_integer(
    'random_scale', 0,
    """A percentage determining how much to randomly scale up the size of the"""
    """ training images by.""")
tf.app.flags.DEFINE_integer(
    'random_brightness', 0,
    """A percentage determining how much to randomly multiply the training"""
    """ image input pixels up or down by.""")

# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and
# their sizes. If you want to adapt this script to work with another model,
# you will need to update these to reflect the values in the network you're
# using.
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
BOTTLENECK_TENSOR_SIZE = 2048
MODEL_INPUT_WIDTH = 299
MODEL_INPUT_HEIGHT = 299
MODEL_INPUT_DEPTH = 3
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the subfolders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for
      validation.

  Returns:
    A dictionary containing an entry for each label subfolder, with images
    split into training, testing, and validation sets within each label.
  """
  if not gfile.Exists(image_dir):
    print("Image directory '" + image_dir + "' not found.")
    return None
  result = {}
  sub_dirs = [x[0] for x in os.walk(image_dir)]
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    print("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(glob.glob(file_glob))
    if not file_list:
      print('No files found')
      continue
    if len(file_list) < 20:
      print('WARNING: Folder has less than 20 images, which may cause issues.')
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, so that the data set creator
      # has a way of grouping photos that are close variations of each other.
      # For example this is used in the plant disease data set to group
      # multiple pictures of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file
      # should go into the training, testing, or validation sets, and we want
      # to keep existing files in the same set even if more files are
      # subsequently added. To do that, we need a stable way of deciding based
      # on just the file name itself, so we do a hash of it and then use that
      # to generate a probability value that we use to assign it.
      hash_name_hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
      percentage_hash = (int(hash_name_hashed, 16) % (65536)) * (100 / 65535.0)
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result


def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of the set to pull images from - training, testing,
      or validation.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  if label_name not in image_lists:
    tf.logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    tf.logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    tf.logging.fatal('Category has no images - %s.', category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path


def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category):
  """Returns a path to a bottleneck file for a label at the given index.

  Args:
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    category: Name string of the set to pull images from - training, testing,
      or validation.

  Returns:
    File system path string to a bottleneck file that meets the requested
    parameters.
  """
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '.txt'


def create_inception_graph():
  """Creates a graph from the saved GraphDef file and returns a Graph object.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: Numpy array of image data.
    image_data_tensor: Input data layer in the graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  bottleneck_values = sess.run(
      bottleneck_tensor,
      {image_data_tensor: image_data})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def maybe_download_and_extract():
  """Downloads and extracts the model tar file.

  If the pretrained model we're using doesn't already exist, this function
  downloads it from the TensorFlow.org website and unpacks it into a
  directory.
  """
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL,
                                             filepath,
                                             reporthook=_progress)
    print()
    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk.

  Args:
    dir_name: Path string of the folder we want to create.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             bottleneck_tensor):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on disk, return that,
  otherwise calculate the data and save it to disk for future use.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of the set to pull images from - training, testing,
      or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    bottleneck_tensor: The output tensor for the bottleneck values.

  Returns:
    Numpy array of values produced by the bottleneck layer for the image.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category)
  if not os.path.exists(bottleneck_path):
    print('Creating bottleneck at ' + bottleneck_path)
    image_path = get_image_path(image_lists, label_name, index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    image_data = gfile.FastGFile(image_path, 'rb').read()
    bottleneck_values = run_bottleneck_on_image(sess, image_data,
                                                jpeg_data_tensor,
                                                bottleneck_tensor)
    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
    with open(bottleneck_path, 'w') as bottleneck_file:
      bottleneck_file.write(bottleneck_string)

  with open(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times (if no
  distortions are applied during training), it can speed things up a lot if we
  calculate the bottleneck layer values once for each image during
  preprocessing, and then just read those cached values repeatedly during
  training. Here we go through all the images we've found, calculate those
  values, and save them off.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    image_dir: Root folder string of the subfolders containing the training
      images.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    bottleneck_tensor: The penultimate output layer of the graph.

  Returns:
    Nothing.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(sess, image_lists, label_name, index,
                                 image_dir, category, bottleneck_dir,
                                 jpeg_data_tensor, bottleneck_tensor)
        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          print(str(how_many_bottlenecks) + ' bottleneck files created.')


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  bottleneck_tensor):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk. It picks a random set of images from
  the specified category.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: The number of bottleneck values to return.
    category: Name string of which set to pull from - training, testing, or
      validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    image_dir: Root folder string of the subfolders containing the training
      images.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truthes = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(65536)
    bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                          image_index, image_dir, category,
                                          bottleneck_dir, jpeg_data_tensor,
                                          bottleneck_tensor)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck)
    ground_truthes.append(ground_truth)
  return bottlenecks, ground_truthes


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have to
  recalculate the full model for every image, and so we can't use cached
  bottleneck values. Instead we find random images for the requested category,
  run them through the distortion graph, and then the full graph to get the
  bottleneck results for each.

  Args:
    sess: Current active TensorFlow Session.
    image_lists: Dictionary of training images for each label.
    how_many: The integer number of bottleneck values to return.
    category: Name string of which set to pull from - training, testing, or
      validation.
    image_dir: Root folder string of the subfolders containing the training
      images.
    input_jpeg_tensor: The input layer we feed the image data to.
    distorted_image: The output node of the distortion graph.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truthes = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = image_lists.keys()[label_index]
    image_index = random.randrange(65536)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    jpeg_data = gfile.FastGFile(image_path, 'r').read()
    # Note that we materialize the distorted_image_data as a numpy array
    # before sending it to be run through the recognition graph. This involves
    # two memory copies and might be optimized in other implementations.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck = run_bottleneck_on_image(sess, distorted_image_data,
                                         resized_input_tensor,
                                         bottleneck_tensor)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck)
    ground_truthes.append(ground_truth)
  return bottlenecks, ground_truthes


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Whether any distortions are enabled, from the input flags.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    Boolean value indicating whether any distortions should be applied.
  """
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect the
  kind of variations we expect in the real world, and so can help train the
  model to cope with natural data more effectively. Here we take the supplied
  parameters and construct a network of operations to apply them to an image.

  Cropping
  ~~~~~~~~

  Cropping is done by placing a bounding box at a random position in the full
  image. The cropping parameter controls the size of that box relative to the
  input image. If it's zero, then the box is the same size as the input and no
  cropping is performed. If the value is 50%, then the crop box will be half the
  width and height of the input. In a diagram it looks like this:

   <       width         >
  +---------------------+
  |                     |
  |   width - crop%     |
  |    <      >         |
  |    +------+         |
  |    |      |         |
  |    |      |         |
  |    |      |         |
  |    +------+         |
  |                     |
  |                     |
  +---------------------+

  Scaling
  ~~~~~~~

  Scaling is a lot like cropping, except that the bounding box is always
  centered and its size varies randomly within the given range. For example if
  the scale percentage is zero, then the bounding box is the same size as the
  input and no scaling is applied. If it's 50%, then the bounding box will be in
  a random range between half the width and height and full size.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    The jpeg input layer and the distorted result tensor.
  """
  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data)
  decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.mul(margin_scale_value, resize_scale_value)
  precrop_width = tf.mul(scale_value, MODEL_INPUT_WIDTH)
  precrop_height = tf.mul(scale_value, MODEL_INPUT_HEIGHT)
  precrop_shape = tf.pack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
                                  MODEL_INPUT_DEPTH])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(tensor_shape.scalar(),
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.mul(flipped_image, brightness_value)
  distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
  return jpeg_data, distort_result


def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
  """Adds a new softmax and fully-connected layer for training.

  We need to retrain the top layer to identify our new classes, so this
  function adds the right operations to the graph, along with some variables
  to hold the weights, and then sets up all the gradients for the backward
  pass.

  The set up for the softmax and fully-connected layers is based on:
  https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
      recognize.
    final_tensor_name: Name string for the new final node that produces
      results.
    bottleneck_tensor: The output of the main CNN graph.

  Returns:
    The tensors for the training and cross entropy results, and tensors for
    the bottleneck input and ground truth input.
  """
  bottleneck_input = tf.placeholder_with_default(
      bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
      name='BottleneckInputPlaceholder')
  layer_weights = tf.Variable(
      tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001),
      name='final_weights')
  layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
  logits = tf.matmul(bottleneck_input, layer_weights,
                     name='final_matmul') + layer_biases
  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
  ground_truth_input = tf.placeholder(tf.float32,
                                      [None, class_count],
                                      name='GroundTruthInput')
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits, ground_truth_input)
  cross_entropy_mean = tf.reduce_mean(cross_entropy)
  train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
      cross_entropy_mean)
  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data into.

  Returns:
    The evaluation operation.
  """
  correct_prediction = tf.equal(
      tf.argmax(result_tensor, 1), tf.argmax(ground_truth_tensor, 1))
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
  return evaluation_step


def main(_):
  # Set up the pre-trained graph.
  maybe_download_and_extract()
  graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
      create_inception_graph())

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    print('No valid folders of images found at ' + FLAGS.image_dir)
    return -1
  if class_count == 1:
    print('Only one valid folder of images found at ' + FLAGS.image_dir +
          ' - multiple classes are needed for classification.')
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)
  sess = tf.Session()

  if do_distort_images:
    # We will be applying distortions, so set up the operations we'll need.
    distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions(
        FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
        FLAGS.random_brightness)
  else:
    # We'll make sure we've calculated the 'bottleneck' image summaries and
    # cached them on disk.
    cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor)

  # Add the new layer that we'll be training.
  (train_step, cross_entropy, bottleneck_input, ground_truth_input,
   final_tensor) = add_final_training_ops(len(image_lists.keys()),
                                          FLAGS.final_tensor_name,
                                          bottleneck_tensor)

  # Set up all our weights to their initial default values.
  init = tf.initialize_all_variables()
  sess.run(init)

  # Create the operations we need to evaluate the accuracy of our new layer.
  evaluation_step = add_evaluation_step(final_tensor, ground_truth_input)

  # Run the training for as many cycles as requested on the command line.
  for i in range(FLAGS.how_many_training_steps):
    # Get a batch of input bottleneck values, either calculated fresh every
    # time with distortions applied, or from the cache stored on disk.
    if do_distort_images:
      train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.image_dir, distorted_jpeg_data_tensor,
          distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
    else:
      train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
          bottleneck_tensor)
    # Feed the bottlenecks and ground truth into the graph, and run a training
    # step.
    sess.run(train_step,
             feed_dict={bottleneck_input: train_bottlenecks,
                        ground_truth_input: train_ground_truth})
    # Every so often, print out how well the graph is training.
    is_last_step = (i + 1 == FLAGS.how_many_training_steps)
    if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
      train_accuracy, cross_entropy_value = sess.run(
          [evaluation_step, cross_entropy],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
                                                      train_accuracy * 100))
      print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
                                                 cross_entropy_value))
      validation_bottlenecks, validation_ground_truth = (
          get_random_cached_bottlenecks(
              sess, image_lists, FLAGS.validation_batch_size, 'validation',
              FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
              bottleneck_tensor))
      validation_accuracy = sess.run(
          evaluation_step,
          feed_dict={bottleneck_input: validation_bottlenecks,
                     ground_truth_input: validation_ground_truth})
      print('%s: Step %d: Validation accuracy = %.1f%%' %
            (datetime.now(), i, validation_accuracy * 100))

  # We've completed all our training, so run a final test evaluation on some
  # new images we haven't used before.
  test_bottlenecks, test_ground_truth = get_random_cached_bottlenecks(
      sess, image_lists, FLAGS.test_batch_size, 'testing',
      FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
      bottleneck_tensor)
  test_accuracy = sess.run(
      evaluation_step,
      feed_dict={bottleneck_input: test_bottlenecks,
                 ground_truth_input: test_ground_truth})
  print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

  # Write out the trained graph and labels with the weights stored as
  # constants.
  output_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
  with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
    f.write('\n'.join(image_lists.keys()) + '\n')


if __name__ == '__main__':
  tf.app.run()
# My English is limited, haha. Original source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
Because of time constraints I cut the run off after finishing class 3, so four classes of 28*28 images were used, each with about 6,000 images, split train:validation:test = 4:1:1. The training images live under TRAINIMAGE and the images used for inference under VALIDATEIMAGE. Training took about two and a half days (56 hours). The final recognition results are as follows:
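As the docstring above notes, the retrained top layer for N labels has N + 2048*N parameters; for the four classes here that is 2048 * 4 = 8192 weights plus 4 biases, about 8.2k trainable parameters in total, held in the two variables (final_weights and final_biases) that get frozen at the end of training.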
1. Training on classes 0, 1, 2, 3.
user01@user01-forfish:~/Documents/tensorflow$ bazel-bin/tensorflow/examples/image_retraining/retrain --image_dir ~/FISHMNIST2/TRAINIMAGE/
Looking for images in '1'
Looking for images in '0'
Looking for images in '3'
Looking for images in '2'
100 bottleneck files created.
200 bottleneck files created.
300 bottleneck files created.
400 bottleneck files created.
............
24300 bottleneck files created.
24400 bottleneck files created.
24500 bottleneck files created.
24600 bottleneck files created.
24700 bottleneck files created.
2016-04-07 16:03:35.848447: Step 0: Train accuracy = 56.0%
2016-04-07 16:03:35.848516: Step 0: Cross entropy = 1.286888
2016-04-07 16:03:36.059755: Step 0: Validation accuracy = 35.0%
2016-04-07 16:03:37.677323: Step 10: Train accuracy = 92.0%
2016-04-07 16:03:37.677392: Step 10: Cross entropy = 0.810003
2016-04-07 16:03:37.817034: Step 10: Validation accuracy = 84.0%
2016-04-07 16:03:39.249091: Step 20: Train accuracy = 93.0%
2016-04-07 16:03:39.249168: Step 20: Cross entropy = 0.611544
............
2016-04-07 16:07:54.362384: Step 3980: Validation accuracy = 99.0%
2016-04-07 16:07:54.905122: Step 3990: Train accuracy = 100.0%
2016-04-07 16:07:54.905188: Step 3990: Cross entropy = 0.028465
2016-04-07 16:07:54.954875: Step 3990: Validation accuracy = 99.0%
2016-04-07 16:07:55.431090: Step 3999: Train accuracy = 100.0%
2016-04-07 16:07:55.431152: Step 3999: Cross entropy = 0.044311
2016-04-07 16:07:55.480975: Step 3999: Validation accuracy = 97.0%
Final test accuracy = 99.0%
Converted 2 variables to const ops.
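The retrained graph and its labels are now at /tmp/output_graph.pb and /tmp/output_labels.txt. Besides the C++ label_image binary used below, the frozen graph can also be loaded directly from Python. The following is only a sketch, assuming the TensorFlow 0.x-era API used above, the default output paths, and the tensor names from retrain.py ('DecodeJpeg/contents:0' and 'final_result:0'); the classify_image helper name is mine, not part of the original example.

# Sketch (not from the original post): load the retrained graph and classify
# one image in Python.
import tensorflow as tf
from tensorflow.python.platform import gfile


def classify_image(image_path,
                   graph_path='/tmp/output_graph.pb',
                   labels_path='/tmp/output_labels.txt'):
  # Read the frozen GraphDef produced by retrain.py and import it.
  with gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
  # Labels were written one per line, in the same order as the softmax output.
  with gfile.FastGFile(labels_path, 'r') as f:
    labels = [line.strip() for line in f.readlines()]
  with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    image_data = gfile.FastGFile(image_path, 'rb').read()
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})[0]
    # Print the scores from highest to lowest, like label_image does.
    for i in predictions.argsort()[::-1]:
      print('%s: %f' % (labels[i], predictions[i]))


if __name__ == '__main__':
  classify_image('/path/to/some_image.jpg')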
2. Using the training results from above.
(1) Recognizing a handwritten "0":
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && \
> bazel-bin/tensorflow/examples/label_image/label_image \
> --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
> --output_layer=final_result \
> --image=$HOME/FISHMNIST2/VALIDATEIMAGE/0/MUSIC-1383.jpg
..............
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 42.557s, Critical Path: 23.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 0 (1): 0.999816
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.000103101
I tensorflow/examples/label_image/main.cc:207] 3 (2): 7.5385e-05
I tensorflow/examples/label_image/main.cc:207] 1 (0): 5.35907e-06
(2) Recognizing a handwritten "1":
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/1/JAVA-10000.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.196s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 1 (0): 0.999564
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.000373375
I tensorflow/examples/label_image/main.cc:207] 3 (2): 4.03145e-05
I tensorflow/examples/label_image/main.cc:207] 0 (1): 2.2671e-05
(3) Recognizing a handwritten "2":
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/2/CHUAN-10006.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.146s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
............
GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.99146
I tensorflow/examples/label_image/main.cc:207] 3 (2): 0.00433606
I tensorflow/examples/label_image/main.cc:207] 1 (0): 0.00381229
I tensorflow/examples/label_image/main.cc:207] 0 (1): 0.000391908
(4) Recognizing a handwritten "3":
user01@user01-forfish:~/Documents/tensorflow$ bazel build tensorflow/examples/label_image:label_image && bazel-bin/tensorflow/examples/label_image/label_image --graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt --output_layer=final_result --image=$HOME/FISHMNIST2/VALIDATEIMAGE/3/YAHOO-1089.jpg
INFO: Found 1 target...
Target //tensorflow/examples/label_image:label_image up-to-date:
bazel-bin/tensorflow/examples/label_image/label_image
INFO: Elapsed time: 0.121s, Critical Path: 0.00s
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in
............
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
W tensorflow/core/kernels/batch_norm_op.cc:36] Op is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
I tensorflow/examples/label_image/main.cc:207] 3 (2): 0.99796
I tensorflow/examples/label_image/main.cc:207] 2 (3): 0.00196124
I tensorflow/examples/label_image/main.cc:207] 0 (1): 5.19803e-05
I tensorflow/examples/label_image/main.cc:207] 1 (0): 2.67401e-05
The recognition accuracy is quite high.
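To back that impression up with a number, a small script along these lines could run the exported graph over the whole ~/FISHMNIST2/VALIDATEIMAGE tree and report overall accuracy. Again only a sketch, under the same assumptions as the Python snippet above (default output paths and the retrain.py tensor names); the evaluate_directory helper is mine, and the folder layout is the one shown in the logs.

# Sketch (not from the original post): walk VALIDATEIMAGE/<label>/ folders,
# run each jpg through the retrained graph, and report how often the top
# prediction matches the folder name.
import os
import tensorflow as tf
from tensorflow.python.platform import gfile


def evaluate_directory(image_root,
                       graph_path='/tmp/output_graph.pb',
                       labels_path='/tmp/output_labels.txt'):
  with gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
  with gfile.FastGFile(labels_path, 'r') as f:
    labels = [line.strip() for line in f.readlines()]
  correct, total = 0, 0
  with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    for true_label in os.listdir(image_root):
      label_dir = os.path.join(image_root, true_label)
      if not os.path.isdir(label_dir):
        continue
      for file_name in os.listdir(label_dir):
        if not file_name.lower().endswith(('.jpg', '.jpeg')):
          continue
        image_data = gfile.FastGFile(os.path.join(label_dir, file_name),
                                     'rb').read()
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})[0]
        # The folder name is the ground-truth label for every image in it.
        correct += int(labels[predictions.argmax()] == true_label)
        total += 1
  print('Accuracy: %d / %d = %.1f%%' % (correct, total,
                                        100.0 * correct / total))


if __name__ == '__main__':
  evaluate_directory(os.path.expanduser('~/FISHMNIST2/VALIDATEIMAGE'))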