
3.1 How fine-tuning works


3.2 Data preparation

  • Split the dataset into a training set and a validation set.
  • Convert the data to TFRecord format.
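
The first step above can be sketched in plain Python. This is only an illustration under assumptions: the function name `split_dataset`, the 20% validation fraction, and the fixed seed are not from the original, and the source images are assumed to sit in one sub-directory per label. It targets Python 3.

```python
import os
import random
import shutil


def split_dataset(source_dir, output_dir, validation_fraction=0.2, seed=12345):
    """Copy source_dir/<label>/<image> into output_dir/train and output_dir/validation.

    Returns a dict mapping each split name to the number of files copied.
    """
    rng = random.Random(seed)  # fixed seed so the split is repeatable
    counts = {'train': 0, 'validation': 0}
    for label in sorted(os.listdir(source_dir)):
        label_dir = os.path.join(source_dir, label)
        if not os.path.isdir(label_dir):
            continue  # skip stray files next to the label folders
        images = sorted(os.listdir(label_dir))
        rng.shuffle(images)
        num_validation = int(len(images) * validation_fraction)
        for index, image in enumerate(images):
            split = 'validation' if index < num_validation else 'train'
            target_dir = os.path.join(output_dir, split, label)
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy(os.path.join(label_dir, image), target_dir)
            counts[split] += 1
    return counts
```

Calling `split_dataset('raw_images', 'pic')` (both paths hypothetical) would produce the `pic/train` and `pic/validation` folders that the conversion script reads.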


python data_convert.py -t pic/ \
  --train-shards 2 \
  --validation-shards 2 \
  --num-threads 2 \
  --dataset-name satellite


"""Converts image data to TFRecords file format with Example protos.

The image data set is expected to reside in JPEG files located in the
following directory structure.

  data_dir/<label>/<image file>

where the sub-directory is the unique label associated with these images.

This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of TFRecord files named

  <dataset_name>_<split>_<shard index>-of-<shard count>.tfrecord

(for example, satellite_train_00000-of-00002.tfrecord), where the shard count
for each data set comes from train_shards and validation_shards. Each record
within the TFRecord file is a serialized Example proto. The Example proto
contains the following fields:

  image/encoded: string containing JPEG encoded image in RGB colorspace
  image/height: integer, image height in pixels
  image/width: integer, image width in pixels
  image/colorspace: string, specifying the colorspace, always 'RGB'
  image/channels: integer, specifying the number of channels, always 3
  image/format: string, specifying the format, always 'JPEG'
  image/filename: string containing the basename of the image file
            e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
  image/class/label: integer specifying the index in a classification layer,
    starting from class_label_base
  image/class/text: string specifying the human-readable version of the label
    e.g. 'dog'

If your data set involves bounding boxes, please look at build_imagenet_data.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os
import random
import sys
import threading

import numpy as np
import tensorflow as tf
import logging

def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    value = tf.compat.as_bytes(value)
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _convert_to_example(filename, image_buffer, label, text, height, width):
    """Build an Example proto for an example.

    Args:
      filename: string, path to an image file, e.g., '/path/to/example.JPG'
      image_buffer: string, JPEG encoding of RGB image
      label: integer, identifier for the ground truth for the network
      text: string, unique human-readable, e.g. 'dog'
      height: integer, image height in pixels
      width: integer, image width in pixels
    Returns:
      Example proto
    """

    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/class/label': _int64_feature(label),
        'image/class/text': _bytes_feature(text),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(os.path.basename(filename)),
        'image/encoded': _bytes_feature(image_buffer)}))
    return example

class ImageCoder(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Create a single Session to run all image coding calls.
        self._sess = tf.Session()

        # Initializes function that converts PNG to JPEG data.
        self._png_data = tf.placeholder(dtype=tf.string)
        image = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)

        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        return self._sess.run(self._png_to_jpeg,
                              feed_dict={self._png_data: image_data})

    def decode_jpeg(self, image_data):
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image

def _is_png(filename):
    """Determine if a file contains a PNG format image.

    Args:
      filename: string, path of the image file.
    Returns:
      boolean indicating if the image is a PNG.
    """
    return '.png' in filename

def _process_image(filename, coder):
    """Process a single image file.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Read the image file.
    with open(filename, 'rb') as f:
        image_data = f.read()

    # Convert any PNG to JPEG's for consistency.
    if _is_png(filename):
        logging.info('Converting PNG to JPEG for %s' % filename)
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width

def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
                               texts, labels, num_shards, command_args):
    """Processes and saves list of images as TFRecord in 1 thread.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, index of the batch to run, within [0, len(ranges)).
      ranges: list of pairs of integers specifying the range of each batch to
        analyze in parallel.
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth
      num_shards: integer number of shards for this data set.
      command_args: parsed arguments; dataset_name and output_directory are
        read here.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    shard_ranges = np.linspace(ranges[thread_index][0],
                               ranges[thread_index][1],
                               num_shards_per_batch + 1).astype(int)
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name,
        # e.g. 'satellite_train_00002-of-00010.tfrecord'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s_%s_%.5d-of-%.5d.tfrecord' % (command_args.dataset_name, name, shard, num_shards)
        output_file = os.path.join(command_args.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            filename = filenames[i]
            label = labels[i]
            text = texts[i]

            image_buffer, height, width = _process_image(filename, coder)

            example = _convert_to_example(filename, image_buffer, label,
                                          text, height, width)
            # Serialize the Example proto and append it to the current shard.
            writer.write(example.SerializeToString())
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                logging.info('%s [thread %d]: Processed %d of %d images in thread batch.' %
                             (datetime.now(), thread_index, counter, num_files_in_thread))

        # Close the shard so that all records are flushed to disk.
        writer.close()
        logging.info('%s [thread %d]: Wrote %d images to %s' %
                     (datetime.now(), thread_index, shard_counter, output_file))
    logging.info('%s [thread %d]: Wrote %d images to %d shards.' %
                 (datetime.now(), thread_index, counter, num_shards_per_batch))

def _process_image_files(name, filenames, texts, labels, num_shards, command_args):
    """Process and save list of images as TFRecord of Example protos.

    Args:
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth
      num_shards: integer number of shards for this data set.
      command_args: parsed arguments; num_threads is read here.
    """
    assert len(filenames) == len(texts)
    assert len(filenames) == len(labels)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    spacing = np.linspace(0, len(filenames), command_args.num_threads + 1).astype(int)
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])

    # Launch a thread for each batch.
    logging.info('Launching %d threads for spacings: %s' % (command_args.num_threads, ranges))

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges, name, filenames,
                texts, labels, num_shards, command_args)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    logging.info('%s: Finished writing all %d images in data set.' %
                 (datetime.now(), len(filenames)))

def _find_image_files(data_dir, labels_file, command_args):
    """Build a list of all images files and labels in the data set.

    Args:
      data_dir: string, path to the root directory of images.
        Assumes that the image data set resides in JPEG files located in
        the following directory structure.

          data_dir/dog/<image file>

        where 'dog' is the label associated with these images.
      labels_file: string, path to the labels file.
        The list of valid labels are held in this file. Assumes that the file
        contains entries as such:

          dog

        where each line corresponds to a label. We map each label contained in
        the file to an integer, starting with command_args.class_label_base
        for the label contained in the first line.
      command_args: parsed arguments; class_label_base is read here.
    Returns:
      filenames: list of strings; each string is a path to an image file.
      texts: list of strings; each string is the class, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth.
    """
    logging.info('Determining list of input files and labels from %s.' % data_dir)
    unique_labels = [l.strip() for l in tf.gfile.FastGFile(
        labels_file, 'r').readlines()]

    labels = []
    filenames = []
    texts = []

    # Label indices start at class_label_base (0 by default).
    label_index = command_args.class_label_base

    # Construct the list of JPEG files and labels.
    for text in unique_labels:
        jpeg_file_path = '%s/%s/*' % (data_dir, text)
        matching_files = tf.gfile.Glob(jpeg_file_path)

        labels.extend([label_index] * len(matching_files))
        texts.extend([text] * len(matching_files))
        filenames.extend(matching_files)

        if not label_index % 100:
            logging.info('Finished finding files in %d of %d classes.' % (
                label_index, len(unique_labels)))
        label_index += 1

    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    shuffled_index = list(range(len(filenames)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    filenames = [filenames[i] for i in shuffled_index]
    texts = [texts[i] for i in shuffled_index]
    labels = [labels[i] for i in shuffled_index]

    logging.info('Found %d JPEG files across %d labels inside %s.' %
                 (len(filenames), len(unique_labels), data_dir))
    # print(labels)
    return filenames, texts, labels

def _process_dataset(name, directory, num_shards, labels_file, command_args):
    """Process a complete data set and save it as a TFRecord.

    Args:
      name: string, unique identifier specifying the data set.
      directory: string, root path to the data set.
      num_shards: integer number of shards for this data set.
      labels_file: string, path to the labels file.
      command_args: parsed arguments, passed through to the helpers.
    """
    filenames, texts, labels = _find_image_files(directory, labels_file, command_args)
    _process_image_files(name, filenames, texts, labels, num_shards, command_args)

def check_and_set_default_args(command_args):
    if not (hasattr(command_args, 'train_shards')) or command_args.train_shards is None:
        command_args.train_shards = 5
    if not (hasattr(command_args, 'validation_shards')) or command_args.validation_shards is None:
        command_args.validation_shards = 5
    if not (hasattr(command_args, 'num_threads')) or command_args.num_threads is None:
        command_args.num_threads = 5
    if not (hasattr(command_args, 'class_label_base')) or command_args.class_label_base is None:
        command_args.class_label_base = 0
    if not (hasattr(command_args, 'dataset_name')) or command_args.dataset_name is None:
        command_args.dataset_name = ''
    assert not command_args.train_shards % command_args.num_threads, (
        'Please make the command_args.num_threads commensurate with command_args.train_shards')
    assert not command_args.validation_shards % command_args.num_threads, (
        'Please make the command_args.num_threads commensurate with '
        'command_args.validation_shards')
    assert command_args.train_directory is not None
    assert command_args.validation_directory is not None
    assert command_args.labels_file is not None
    assert command_args.output_directory is not None

def main(command_args):
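    """Convert the training and validation sets to TFRecord files.

    command_args must provide the following attributes:

    command_args.train_directory: folder holding the training set. Each
      sub-folder is named after a label and contains that label's images.
    command_args.validation_directory: folder holding the validation set,
      laid out the same way.
    command_args.labels_file: a file with one label name per line.
    command_args.output_directory: folder that receives the output files.

    command_args.train_shards: number of shards for the training set.
    command_args.validation_shards: number of shards for the validation set.
    command_args.num_threads: number of threads. Must be a divisor of both
      shard counts.

    command_args.class_label_base: important! The label index assigned to the
      first class in the TFRecord files. Defaults to 0 (models/slim also
      starts from 0).
    command_args.dataset_name: string used as the prefix of the output files.
    """
    # Fill in defaults (including class_label_base) and validate the arguments.
    check_and_set_default_args(command_args)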

    logging.info('Saving results to %s' % command_args.output_directory)

    # Run it!
    _process_dataset('validation', command_args.validation_directory,
                     command_args.validation_shards, command_args.labels_file, command_args)
    _process_dataset('train', command_args.train_directory,
                     command_args.train_shards, command_args.labels_file, command_args)


# coding:utf-8
from __future__ import absolute_import
import argparse
import os
import logging
from src.tfrecord import main

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tensorflow-data-dir', default='pic/')
    parser.add_argument('--train-shards', default=2, type=int)
    parser.add_argument('--validation-shards', default=2, type=int)
    parser.add_argument('--num-threads', default=2, type=int)
    parser.add_argument('--dataset-name', default='satellite', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    args.tensorflow_dir = args.tensorflow_data_dir
    args.train_directory = os.path.join(args.tensorflow_dir, 'train')
    args.validation_directory = os.path.join(args.tensorflow_dir, 'validation')
    args.output_directory = args.tensorflow_dir
    args.labels_file = os.path.join(args.tensorflow_dir, 'label.txt')
    if not os.path.exists(args.labels_file):
        logging.warning('Can\'t find label.txt. Now create it.')
        all_entries = os.listdir(args.train_directory)
        dirnames = []
        for entry in all_entries:
            # Every sub-directory of the training folder is one label.
            if os.path.isdir(os.path.join(args.train_directory, entry)):
                dirnames.append(entry)
        with open(args.labels_file, 'w') as f:
            for dirname in dirnames:
                f.write(dirname + '\n')
    # Run the conversion.
    main(args)
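
The order of lines in label.txt decides which integer each class receives, and `os.listdir` returns entries in arbitrary order. One way to keep the mapping stable across machines is to sort the folder names before writing the file. The following is a sketch of that variation, not what the script above does; both function names are hypothetical.

```python
import os


def write_label_file(train_directory, labels_file):
    """Write one sub-directory name per line, sorted so the order is stable."""
    dirnames = sorted(
        entry for entry in os.listdir(train_directory)
        if os.path.isdir(os.path.join(train_directory, entry)))
    with open(labels_file, 'w') as f:
        for dirname in dirnames:
            f.write(dirname + '\n')
    return dirnames


def read_label_map(labels_file, class_label_base=0):
    """Map each label name to its integer index, counting from class_label_base."""
    with open(labels_file) as f:
        names = [line.strip() for line in f if line.strip()]
    return {name: class_label_base + i for i, name in enumerate(names)}
```

`read_label_map` mirrors the way `_find_image_files` numbers the classes, so it can be used to check which index a given folder was assigned.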


3.3 Fine-tuning the model with TensorFlow Slim

3.3.1 Download the TensorFlow Slim source code

3.3.2 Define a new datasets file
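
The dataset file itself is not reproduced in this section. Slim's bundled dataset modules (datasets/flowers.py, for example) declare a file pattern, the size of each split, and the number of classes, so a module for the new data presumably needs the same three values. The sketch below derives them from the pic/ folder used in Section 3.2. The function name `dataset_constants` and the returned keys are illustrative only.

```python
import os


def dataset_constants(pic_dir, dataset_name='satellite'):
    """Collect the values a slim-style dataset module typically hard-codes."""
    with open(os.path.join(pic_dir, 'label.txt')) as f:
        labels = [line.strip() for line in f if line.strip()]
    splits_to_sizes = {}
    for split in ('train', 'validation'):
        total = 0
        for label in labels:
            label_dir = os.path.join(pic_dir, split, label)
            if os.path.isdir(label_dir):
                # Every entry in a label folder was written as one record.
                total += len(os.listdir(label_dir))
        splits_to_sizes[split] = total
    return {
        # Matches the '<dataset_name>_<split>_<shard>-of-<count>.tfrecord' names.
        'file_pattern': dataset_name + '_%s_*.tfrecord',
        'splits_to_sizes': splits_to_sizes,
        'num_classes': len(labels),
    }
```

The counts assume the conversion in Section 3.2 ran over exactly these folders; if images were added or removed afterwards, the sizes need to be recomputed.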


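3.3.3 Prepare the training folder


  • Create a data directory and copy into it the five converted files prepared in Section 3.2.
  • Create an empty train_dir directory to hold the logs and models saved during training.
  • Create a pretrained directory. On the slim GitHub page, find the download address of the Inception V3 model, http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz. Downloading and extracting it yields an inception_v3.ckpt file; copy that file into the pretrained directory (the file is also provided under chapter_3_data/).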
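
The three steps above can be scripted. The sketch below assumes the folders live under a `satellite` root, as the paths in the training command suggest, and that the five files are the four .tfrecord shards plus label.txt; both points are inferences rather than statements from the text. The function name is hypothetical, and the checkpoint is expected to have been downloaded and extracted already.

```python
import glob
import os
import shutil


def prepare_training_folder(root, converted_dir, checkpoint_file=None):
    """Create data/, train_dir/ and pretrained/ under root and fill them."""
    for name in ('data', 'train_dir', 'pretrained'):
        os.makedirs(os.path.join(root, name), exist_ok=True)

    # Copy the converted shards, plus label.txt if it is present.
    sources = sorted(glob.glob(os.path.join(converted_dir, '*.tfrecord')))
    label_file = os.path.join(converted_dir, 'label.txt')
    if os.path.exists(label_file):
        sources.append(label_file)
    copied = []
    for path in sources:
        shutil.copy(path, os.path.join(root, 'data'))
        copied.append(os.path.basename(path))

    # The extracted inception_v3.ckpt goes into pretrained/.
    if checkpoint_file is not None:
        shutil.copy(checkpoint_file, os.path.join(root, 'pretrained'))
    return copied
```

For example, `prepare_training_folder('satellite', 'pic', 'inception_v3.ckpt')` run from the slim directory would set up the layout that the commands in Section 3.3.4 refer to.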

3.3.4 Start training


python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=2 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004


python train_image_classifier.py \
  --train_dir=satellite/train_dir \
  --dataset_name=satellite \
  --dataset_split_name=train \
  --dataset_dir=satellite/data \
  --model_name=inception_v3 \
  --checkpoint_path=satellite/pretrained/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=100000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=300 \
  --save_summaries_secs=10 \
  --log_every_n_steps=1 \
  --optimizer=rmsprop \
  --weight_decay=0.00004
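
The two training commands above differ mainly in `--trainable_scopes`: the first limits weight updates to the Logits and AuxLogits scopes, while the second omits the flag so every layer is trained. Both keep `--checkpoint_exclude_scopes`, so the final layers are never restored from the pretrained checkpoint. The sketch below makes that difference explicit by building either argument list; the function name and the `train_all_layers` switch are illustrative.

```python
LOGITS_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'


def build_train_command(train_all_layers=False):
    """Return the argument list for one of the two training commands above."""
    args = [
        'python', 'train_image_classifier.py',
        '--train_dir=satellite/train_dir',
        '--dataset_name=satellite',
        '--dataset_split_name=train',
        '--dataset_dir=satellite/data',
        '--model_name=inception_v3',
        '--checkpoint_path=satellite/pretrained/inception_v3.ckpt',
        # The final layers are re-initialized rather than restored.
        '--checkpoint_exclude_scopes=' + LOGITS_SCOPES,
    ]
    if not train_all_layers:
        # Only the final layers receive gradient updates.
        args.append('--trainable_scopes=' + LOGITS_SCOPES)
    args += [
        '--max_number_of_steps=100000',
        '--batch_size=32',
        '--learning_rate=0.001',
        '--learning_rate_decay_type=fixed',
        '--save_interval_secs=300',
        # The logging intervals follow the two commands as given above.
        '--save_summaries_secs=%d' % (10 if train_all_layers else 2),
        '--log_every_n_steps=%d' % (1 if train_all_layers else 10),
        '--optimizer=rmsprop',
        '--weight_decay=0.00004',
    ]
    return args
```

The list can be passed to `subprocess.run(build_train_command(), check=True)` from the slim directory, which avoids shell line-continuation problems.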

3.3.6 Verify model accuracy


python eval_image_classifier.py \
  --checkpoint_path=satellite/train_dir \
  --eval_dir=satellite/eval_dir \
  --dataset_name=satellite \
  --dataset_split_name=validation \
  --dataset_dir=satellite/data

3.3.7 TensorBoard visualization and hyperparameter selection


tensorboard --logdir satellite/train_dir

3.3.8 Export the model and classify a single image


python export_inference_graph.py \
  --alsologtostderr \
  --model_name=inception_v3 \
  --output_file=satellite/inception_v3_inf_graph.pb \
  --dataset_name satellite


python freeze_graph.py \
  --input_graph slim/satellite/inception_v3_inf_graph.pb \
  --input_checkpoint slim/satellite/train_dir/model.ckpt-5271 \
  --input_binary true \
  --output_node_names InceptionV3/Predictions/Reshape_1 \
  --output_graph slim/satellite/frozen_graph.pb
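
The step number in `--input_checkpoint` depends on how far training got, so it has to be replaced with a checkpoint that actually exists in train_dir. The helper below is a sketch of one way to find the newest one. It assumes checkpoints are saved as `model.ckpt-<step>.index` files alongside their data files; the function name is hypothetical. Inside TensorFlow, `tf.train.latest_checkpoint(train_dir)` serves the same purpose.

```python
import os
import re

# Matches the index file written for each saved step, e.g. model.ckpt-5271.index
_CHECKPOINT_INDEX = re.compile(r'^model\.ckpt-(\d+)\.index$')


def latest_checkpoint_prefix(train_dir):
    """Return the path prefix of the highest-numbered checkpoint, or None."""
    steps = []
    for name in os.listdir(train_dir):
        match = _CHECKPOINT_INDEX.match(name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(train_dir, 'model.ckpt-%d' % max(steps))
```

The returned prefix, without any file extension, is the form that `--input_checkpoint` expects.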


python classify_image_inception_v3.py \
  --model_path slim/satellite/frozen_graph.pb \
  --label_path data_prepare/pic/label.txt \
  --image_file test_image.jpg
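
The classification script is not listed here, so how it reports its result is not shown. As a rough illustration of the final step, the sketch below turns a vector of class probabilities (such as the output of the `InceptionV3/Predictions/Reshape_1` node) into ranked label names. It assumes `class_label_base` was 0, so that index i corresponds to line i of label.txt; both function names are hypothetical.

```python
def load_labels(label_path):
    """Read label.txt: one class name per line, in index order."""
    with open(label_path) as f:
        return [line.strip() for line in f if line.strip()]


def top_predictions(probabilities, labels, k=5):
    """Return the k most probable (label, probability) pairs, best first."""
    ranked = sorted(enumerate(probabilities),
                    key=lambda pair: pair[1], reverse=True)
    return [(labels[index], score) for index, score in ranked[:k]]
```

For a model with fewer than five classes, `k` simply caps at the number of classes available.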


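  • TensorFlow Slim is a high-level TensorFlow API for defining, training, and evaluating complex networks. The official repository already uses TF-Slim to define a number of common image recognition models, such as AlexNet, VGGNet, the Inception models, and ResNet. The Inception V3 model covered in this chapter is one of them. For detailed documentation, see https://github.com/tensorflow/models/tree/master/research/slim.
  • Section 3.2 converted the image data into TFRecord files. TFRecord is a file format provided by TensorFlow for reading data at high speed. The blog post at http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/ explains in detail how to convert data to TFRecord files and how to read data back from them.
  • Inception V3 is an improved version of the Inception model (GoogLeNet). See the paper Rethinking the Inception Architecture for Computer Vision for the details of its structure.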
