TensorFlow(十八):从零开始训练图片分类模型
(一):进入GitHub下载模型--》下载地址:https://github.com/tensorflow/models (slim位于该仓库的research/slim目录下)
因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。
(1):在slim中新建images文件夹存放图片集
(2):新建model文件夹用来放模型
(3):在datasets文件夹中新建myimages.py文件
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at: tensorflow/models/slim/datasets/download_and_convert_flowers.py """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'image_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 3500, 'test': 500} # 这里根据自己的训练集内容进行修改 _NUM_CLASSES = 5 _ITEMS_TO_DESCRIPTIONS = { 'image': 'A color image of varying size.', 'label': 'A single integer between 0 and 4', } def get_split(split_name, dataset_dir, file_pattern=None, reader=None): """Gets a dataset tuple with instructions for reading flowers. Args: split_name: A train/validation split name. dataset_dir: The base directory of the dataset sources. file_pattern: The file pattern to use when matching the dataset sources. It is assumed that the pattern contains a '%s' string so that the split name can be inserted. reader: The TensorFlow reader type. Returns: A `Dataset` namedtuple. Raises: ValueError: if `split_name` is not a valid train/validation split. """ if split_name not in SPLITS_TO_SIZES: raise ValueError('split name %s was not recognized.' 
% split_name) if not file_pattern: file_pattern = _FILE_PATTERN file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default. if reader is None: reader = tf.TFRecordReader keys_to_features = { 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 'image/class/label': tf.FixedLenFeature( [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } items_to_handlers = { 'image': slim.tfexample_decoder.Image(), 'label': slim.tfexample_decoder.Tensor('image/class/label'), } decoder = slim.tfexample_decoder.TFExampleDecoder( keys_to_features, items_to_handlers) labels_to_names = None if dataset_utils.has_labels(dataset_dir): labels_to_names = dataset_utils.read_label_file(dataset_dir) return slim.dataset.Dataset( data_sources=file_pattern, reader=reader, decoder=decoder, num_samples=SPLITS_TO_SIZES[split_name], items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, num_classes=_NUM_CLASSES, labels_to_names=labels_to_names)
(4):修改dataset_factory.py
from datasets import myimages

# Registry of dataset modules keyed by dataset name.
# NOTE(review): cifar10, flowers, imagenet and mnist are presumably imported
# earlier in dataset_factory.py; those imports are not part of this excerpt.
datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'myimages':myimages,  # This entry is the newly added one.
}
(二):对图片进行处理,生成tfrecord格式的文件。
"""Convert a directory of class-labelled images into sharded TFRecord files.

Each sub-directory of DATASET_DIR is treated as one class and every file
inside it as one image of that class. The images are shuffled, split into a
test set of _NUM_TEST images and a training set of the rest, and each split
is written as _NUM_SHARDS TFRecord files. A labels file mapping class ids to
class names is written next to the shards.
"""
import math
import os
import random
import sys

import tensorflow as tf

# Number of images held out for the test split.
_NUM_TEST = 500
# Seed for the shuffle that precedes the train/test split.
_RANDOM_SEED = 0
# Number of TFRecord shards written per split.
_NUM_SHARDS = 5
# Dataset directory: one sub-directory of images per class.
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
# Base name of the labels file. write_label_file joins it with the dataset
# directory, so it must not already contain that directory.
LABELS_FILENAME = 'labels.txt'


def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Return the full path of one TFRecord shard of the given split."""
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


def _dataset_exists(dataset_dir):
    """Return True only if every shard of both splits already exists."""
    for split_name in ['train', 'test']:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                dataset_dir, split_name, shard_id)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


def _get_filenames_and_classes(dataset_dir):
    """Collect all image paths and the class names found in `dataset_dir`.

    Args:
        dataset_dir: Directory whose sub-directories each hold one class.

    Returns:
        A tuple `(photo_filenames, class_names)`: the paths of all files
        inside the class sub-directories, and the sub-directory names.
        Both are in sorted order so that class ids and the seeded shuffle
        are reproducible (os.listdir order is arbitrary).
    """
    directories = []
    class_names = []
    for filename in sorted(os.listdir(dataset_dir)):
        path = os.path.join(dataset_dir, filename)
        # Only sub-directories are classes; plain files are ignored.
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in sorted(os.listdir(directory)):
            photo_filenames.append(os.path.join(directory, filename))

    return photo_filenames, class_names


def int64_feature(values):
    """Wrap an int (or a tuple/list of ints) in an int64 `tf.train.Feature`."""
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    """Wrap a single bytes value in a bytes `tf.train.Feature`."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, image_format, class_id):
    """Build the `tf.train.Example` for one encoded image and its class id."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
    }))


def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
    """Write one `<id>:<class name>` line per label into the labels file.

    Args:
        labels_to_class_names: Dict mapping integer label ids to class names.
        dataset_dir: Directory in which the labels file is created.
        filename: Name of the labels file inside `dataset_dir`.
    """
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Write the given images into `_NUM_SHARDS` TFRecord files.

    Args:
        split_name: 'train' or 'test'.
        filenames: Paths of the images that belong to this split.
        class_names_to_ids: Dict mapping a class (directory) name to its id.
        dataset_dir: Directory in which the shards are written.
    """
    assert split_name in ['train', 'test']
    # Round up so that the last shard absorbs the remainder; rounding down
    # would silently drop up to _NUM_SHARDS - 1 images.
    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)
        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
            # Index range of the images stored in this shard.
            start_ndx = shard_id * num_per_shard
            end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
            for i in range(start_ndx, end_ndx):
                try:
                    sys.stdout.write(
                        '\r>> Converting image %d/%d shard %d'
                        % (i + 1, len(filenames), shard_id))
                    sys.stdout.flush()
                    # Binary mode is required, otherwise decoding errors
                    # occur; the context manager closes the handle.
                    with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
                        image_data = image_file.read()
                    # The parent directory name is the class name.
                    class_name = os.path.basename(
                        os.path.dirname(filenames[i]))
                    class_id = class_names_to_ids[class_name]
                    example = image_to_tfexample(
                        image_data, b'jpg', class_id)
                    tfrecord_writer.write(example.SerializeToString())
                except IOError as e:
                    # Best effort: report the unreadable file and continue.
                    print("Could not read:", filenames[i])
                    print("Error:", e)
                    print("Skip it\n")

    sys.stdout.write('\n')
    sys.stdout.flush()


if __name__ == '__main__':
    # Skip the conversion when all shards are already present.
    if _dataset_exists(DATASET_DIR):
        print('tfcecord文件已存在')
    else:
        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
        # Map class names to ids, e.g.
        # {'animal': 0, 'flower': 1, 'guitar': 2, 'house': 3, 'plane': 4}
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))

        # Shuffle, then split into the training and test sets.
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)
        training_filenames = photo_filenames[_NUM_TEST:]
        testing_filenames = photo_filenames[:_NUM_TEST]

        _convert_dataset('train', training_filenames, class_names_to_ids,
                         DATASET_DIR)
        _convert_dataset('test', testing_filenames, class_names_to_ids,
                         DATASET_DIR)

        # Write the id -> class name mapping next to the shards.
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, DATASET_DIR)
(三):新建批处理文件,开始训练模型
python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3
rem The last argument line must not end with "^": the caret continues the
rem command onto the next line, which would pass "pause" to python as an
rem extra argument instead of running it.
rem Keep the console window open after the command exits.
pause
注释:
第一行表示运行训练文件,路径为全路径
第二行表示模型存放位置
第三行为数据集名称,即在datasets文件夹中新建的myimages.py、并已在dataset_factory.py中注册的myimages
第四行为使用的训练集
第五行为数据集所在的位置
第六行为批次大小,默认为32,看个人GPU,我用10
第七行为最大训练步数(step),默认无限次
第八行为使用模型名称