自己写的制作 Cityscapes 语义分割 TFRecord 的脚本，适用于 DeepLabv3+

自己写的制作 Cityscapes 语义分割 TFRecord 的脚本，适用于 DeepLabv3+

自用

"""Converts PASCAL dataset to TFRecords file format."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import io
import os
import sys
import natsort
import PIL.Image
import tensorflow as tf

from utils import dataset_util

parser = argparse.ArgumentParser()

parser.add_argument('--data_dir', type=str, default='/home/a/dataset/cityscapes/',
                    help='Path to the root directory of the Cityscapes dataset.')

parser.add_argument('--output_path', type=str, default='./dataset',
                    help='Path to the directory to create TFRecords outputs.')

# NOTE: --train_data_list and --valid_data_list are accepted but not read
# anywhere in this script; the image and label files are discovered by
# walking <data_dir>/<image_data_dir>/<split> and
# <data_dir>/<label_data_dir>/<split> instead.
parser.add_argument('--train_data_list', type=str, default='./dataset/train.txt',
                    help='Path to the file listing the training data (currently unused).')

parser.add_argument('--valid_data_list', type=str, default='./dataset/val.txt',
                    help='Path to the file listing the validation data (currently unused).')

parser.add_argument('--image_data_dir', type=str, default='leftImg8bit',
                    help='Subdirectory of data_dir containing the images '
                         '(with train/ and val/ below it).')

parser.add_argument('--label_data_dir', type=str, default='gtFine',
                    help='Subdirectory of data_dir containing the labelTrainIds '
                         'label maps (with train/ and val/ below it).')


def dict_to_tf_example(image_path,
                       label_path):
  """Convert a PNG image and its PNG label map to a tf.Example proto.

  Both files are stored as their raw encoded bytes; PIL is only used to
  verify the container format and to read the image dimensions.

  Args:
    image_path: Path to a single PNG image (e.g. a Cityscapes leftImg8bit
      file).
    label_path: Path to its corresponding PNG label map.

  Returns:
    example: The converted tf.Example with the features image/height,
      image/width, image/encoded, image/format, label/encoded and
      label/format.

  Raises:
    ValueError: if the image pointed to by image_path is not a valid PNG or
                if the label pointed to by label_path is not a valid PNG or
                if the size of image does not match with that of label.
  """
  with tf.gfile.GFile(image_path, 'rb') as fid:
    encoded_image = fid.read()
  try:
    image = PIL.Image.open(io.BytesIO(encoded_image))
  except IOError:
    # PIL signals undecodable data with IOError/OSError; report it as the
    # documented ValueError so the caller can skip this example.
    raise ValueError('Image is not a decodable image: %s' % image_path)
  if image.format != 'PNG':
    raise ValueError('Image format not PNG')

  with tf.gfile.GFile(label_path, 'rb') as fid:
    encoded_label = fid.read()
  try:
    label = PIL.Image.open(io.BytesIO(encoded_label))
  except IOError:
    raise ValueError('Label is not a decodable image: %s' % label_path)
  if label.format != 'PNG':
    raise ValueError('Label format not PNG')

  if image.size != label.size:
    raise ValueError('The size of image does not match with that of label.')

  # PIL reports size as (width, height).
  width, height = image.size

  example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/encoded': dataset_util.bytes_feature(encoded_image),
    'image/format': dataset_util.bytes_feature('png'.encode('utf8')),
    'label/encoded': dataset_util.bytes_feature(encoded_label),
    'label/format': dataset_util.bytes_feature('png'.encode('utf8')),
  }))
  return example

def scanDir_img_File(dir):
  """Yield the path of every file located anywhere below ``dir``.

  The tree is walked top-down, errors from os.walk are ignored, symlinked
  directories are not followed, and no filtering by extension is done.
  """
  walker = os.walk(dir, topdown=True, onerror=None, followlinks=False)
  for current_root, _subdirs, filenames in walker:
    for name in filenames:
      yield os.path.join(current_root, name)

def scanDir_lable_File(dir):
  """Yield paths of Cityscapes ``labelTrainIds`` label files below ``dir``.

  A regular file is selected when the fifth underscore-separated field of
  its name (extension stripped) is exactly ``labelTrainIds``, e.g.
  ``aachen_000000_000019_gtFine_labelTrainIds.png``. Files whose names have
  fewer than five fields are skipped instead of raising IndexError.
  """
  for root, _dirs, files in os.walk(dir, True, None, False):
    for f in files:
      full_path = os.path.join(root, f)
      if not os.path.isfile(full_path):
        continue
      fields = os.path.splitext(f)[0].split('_')
      # Exact field comparison, not a substring test.
      if len(fields) > 4 and fields[4] == 'labelTrainIds':
        yield full_path

def _cityscapes_key(path):
  """Return the 'city_sequence_frame' prefix of a Cityscapes file name.

  Both 'aachen_000000_000019_leftImg8bit.png' and
  'aachen_000000_000019_gtFine_labelTrainIds.png' map to
  'aachen_000000_000019', which is what links an image to its label.
  """
  return '_'.join(os.path.basename(path).split('_')[:3])


def create_tf_record(output_filename,
                     image_dir,
                     label_dir):
  """Creates a TFRecord file from image/label pairs.

  Images are all files below image_dir (see scanDir_img_File); labels are
  the labelTrainIds files below label_dir (see scanDir_lable_File). Each
  image is paired with the label sharing its 'city_sequence_frame' prefix
  rather than by list position, so a missing or extra file cannot shift
  every later pair. Images without a matching label, and pairs rejected by
  dict_to_tf_example, are logged and skipped.

  Args:
    output_filename: Path to where output file is saved.
    image_dir: Directory where image files are stored.
    label_dir: Directory where label files are stored.
  """
  image_list = natsort.natsorted(scanDir_img_File(image_dir))
  label_list = natsort.natsorted(scanDir_lable_File(label_dir))
  label_by_key = {_cityscapes_key(p): p for p in label_list}

  if len(image_list) != len(label_list):
    tf.logging.warning('Found %d images under %s but %d labels under %s.',
                       len(image_list), image_dir, len(label_list), label_dir)

  # The context manager closes the writer even if an unexpected error
  # escapes the loop.
  with tf.python_io.TFRecordWriter(output_filename) as writer:
    for image_path in image_list:
      label_path = label_by_key.get(_cityscapes_key(image_path))
      if label_path is None:
        tf.logging.warning('No label for image: %s, ignoring.', image_path)
        continue
      print(image_path, label_path)
      try:
        tf_example = dict_to_tf_example(image_path, label_path)
      except ValueError:
        tf.logging.warning('Invalid example: %s, ignoring.', image_path)
        continue
      writer.write(tf_example.SerializeToString())


def main(unused_argv):
  """Write city_train.record and city_val.record into FLAGS.output_path.

  For each split, images come from <data_dir>/<image_data_dir>/<split> and
  labels from <data_dir>/<label_data_dir>/<split>. The output directory is
  created first if it does not exist.
  """
  out_dir = FLAGS.output_path
  if not os.path.exists(out_dir):
    os.makedirs(out_dir)

  tf.logging.info("Reading from CITY dataset")
  splits = (('train', 'city_train.record'),
            ('val', 'city_val.record'))
  for split, record_name in splits:
    create_tf_record(
        os.path.join(out_dir, record_name),
        os.path.join(FLAGS.data_dir, FLAGS.image_data_dir, split),
        os.path.join(FLAGS.data_dir, FLAGS.label_data_dir, split))


if __name__ == '__main__':
  # Raise verbosity first so INFO messages logged from main() are shown.
  tf.logging.set_verbosity(tf.logging.INFO)
  # argparse consumes the options it defines; whatever it does not
  # recognise is forwarded to tf.app.run along with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  forwarded_argv = [sys.argv[0]]
  forwarded_argv.extend(unparsed)
  tf.app.run(main=main, argv=forwarded_argv)

 

 

 
posted @ 2018-03-23 18:00  ayew  阅读(462)  评论(0)  编辑  收藏  举报