MindSpore易点通·精讲系列--数据集加载之TFRecordDataset

Dive Into MindSpore -- TFRecordDataset For Dataset Load

MindSpore易点通·精讲系列--数据集加载之TFRecordDataset

本文开发环境

  • Ubuntu 20.04
  • Python 3.8
  • MindSpore 1.7.0

本文内容摘要

  • 背景介绍
  • 先看文档
  • 生成TFRecord
  • 数据加载
  • 本文总结
  • 本文参考

1. 背景介绍

TFRecord格式是TensorFlow官方设计的一种数据格式。

TFRecord 格式是一种用于存储二进制记录序列的简单格式,该格式能够更好的利用内存,内部包含多个tf.train.Example,在一个Examples消息体中包含一系列的tf.train.feature属性,而每一个feature是一个key-value的键值对,其中key是string类型,value的取值有三种:

  • bytes_list:可以存储stringbyte两种数据类型
  • float_list:可以存储float(float32)double(float64)两种数据类型
  • int64_list:可以存储bool, enum, int32, uint32, int64, uint64数据类型

上面简单介绍了TFRecord的知识,下面我们就要进入正题,来谈谈MindSpore中对TFRecord格式的支持。

2. 先看文档

老传统,先来看看官方对API的描述。

下面对主要参数做简单介绍:

  • dataset_files -- 数据集文件路径。
  • schema -- 读取模式策略,通俗来说就是要读取的tfrecord文件内的数据内容格式。可以通过json或者Schema传入。默认为None不指定。
  • columns_list -- 指定读取的具体数据列。默认全部读取。
  • num_samples -- 指定读取出来的样本数量。
  • shuffle -- 是否对数据进行打乱,可参考之前的文章解读。

3. 生成TFRecord

本文使用的是THUCNews数据集,如果需要将该数据集用于商业用途,请联系数据集作者。

数据集启智社区下载地址

由于下文需要用到TFRecord数据集来做加载,本节先来生成TFRecord数据集。对TensorFlow不了解的读者可以直接照搬代码即可。

生成TFRecord代码如下:

import codecs
import os
import re
import six
import tensorflow as tf

from collections import Counter


def _int64_feature(values):
    """Returns a TF-Feature of int64s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float32_feature(values):
    """Returns a TF-Feature of float32s.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
     """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(values):
    """Returns a TF-Feature of bytes.
    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def convert_to_feature(values):
    """Convert to TF-Feature based on the type of element in values.

    Args:
        values: A scalar or list of values.

    Returns:
        A TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]

    if isinstance(values[0], int):
        return _int64_feature(values)
    elif isinstance(values[0], float):
        return _float32_feature(values)
    elif isinstance(values[0], bytes):
        return _bytes_feature(values)
    else:
        raise ValueError("feature type {0} is not supported now !".format(type(values[0])))


def dict_to_example(dictionary):
    """Converts a dictionary of string->int to a tf.Example."""
    features = {}
    for k, v in six.iteritems(dictionary):
        features[k] = convert_to_feature(values=v)
    return tf.train.Example(features=tf.train.Features(feature=features))


def get_txt_files(data_dir):
    cls_txt_dict = {}
    txt_file_list = []

    # get files list and class files list.
    sub_data_name_list = next(os.walk(data_dir))[1]
    sub_data_name_list = sorted(sub_data_name_list)
    for sub_data_name in sub_data_name_list:
        sub_data_dir = os.path.join(data_dir, sub_data_name)
        data_name_list = next(os.walk(sub_data_dir))[2]
        data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
        cls_txt_dict[sub_data_name] = data_file_list
        txt_file_list.extend(data_file_list)
        num_data_files = len(data_file_list)
        print("{}: {}".format(sub_data_name, num_data_files), flush=True)
    num_txt_files = len(txt_file_list)
    print("total: {}".format(num_txt_files), flush=True)

    return cls_txt_dict, txt_file_list


def get_txt_data(txt_file):
    with codecs.open(txt_file, "r", "UTF8") as fp:
        txt_content = fp.read()
    txt_data = re.sub("\s+", " ", txt_content)

    return txt_data


def build_vocab(txt_file_list, vocab_size=7000):
    counter = Counter()
    for txt_file in txt_file_list:
        txt_data = get_txt_data(txt_file)
        counter.update(txt_data)

    num_vocab = len(counter)
    if num_vocab < vocab_size - 1:
        real_vocab_size = num_vocab + 2
    else:
        real_vocab_size = vocab_size

    # pad_id is 0, unk_id is 1
    vocab_dict = {word_freq[0]: ix + 1 for ix, word_freq in enumerate(counter.most_common(real_vocab_size - 2))}

    print("real vocab size: {}".format(real_vocab_size), flush=True)
    print("vocab dict:\n{}".format(vocab_dict), flush=True)

    return vocab_dict


def make_tfrecords(
        data_dir, tfrecord_dir, vocab_size=7000, min_seq_length=10, max_seq_length=800,
        num_train=8, num_test=2, start_fid=0):
    # get txt files
    cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
    # map word to id
    vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size=vocab_size)
    # map class to id
    class_dict = {class_name: ix for ix, class_name in enumerate(cls_txt_dict.keys())}

    train_writers = []
    for fid in range(start_fid, num_train+start_fid):
        tfrecord_file = os.path.join(tfrecord_dir, "train_{:04d}.tfrecord".format(fid))
        writer = tf.io.TFRecordWriter(tfrecord_file)
        train_writers.append(writer)

    test_writers = []
    for fid in range(start_fid, num_test+start_fid):
        tfrecord_file = os.path.join(tfrecord_dir, "test_{:04d}.tfrecord".format(fid))
        writer = tf.io.TFRecordWriter(tfrecord_file)
        test_writers.append(writer)

    pad_id = 0
    unk_id = 1
    num_samples = 0
    num_train_samples = 0
    num_test_samples = 0
    for class_name, class_file_list in cls_txt_dict.items():
        class_id = class_dict[class_name]
        num_class_pass = 0
        for txt_file in class_file_list:
            txt_data = get_txt_data(txt_file=txt_file)
            txt_len = len(txt_data)
            if txt_len < min_seq_length:
                num_class_pass += 1
                continue
            if txt_len > max_seq_length:
                txt_data = txt_data[:max_seq_length]
                txt_len = max_seq_length
            word_ids = []
            for word in txt_data:
                word_id = vocab_dict.get(word, unk_id)
                word_ids.append(word_id)
            for _ in range(max_seq_length - txt_len):
                word_ids.append(pad_id)

            example = dict_to_example({"input": word_ids, "class": class_id})
            num_samples += 1
            if num_samples % 10 == 0:
                num_test_samples += 1
                writer_id = num_test_samples % num_test
                test_writers[writer_id].write(example.SerializeToString())
            else:
                num_train_samples += 1
                writer_id = num_train_samples % num_train
                train_writers[writer_id].write(example.SerializeToString())
        print("{} pass: {}".format(class_name, num_class_pass), flush=True)

    for writer in train_writers:
        writer.close()
    for writer in test_writers:
        writer.close()

    print("num samples: {}".format(num_samples), flush=True)
    print("num train samples: {}".format(num_train_samples), flush=True)
    print("num test samples: {}".format(num_test_samples), flush=True)


def main():
    data_dir = "{your_data_dir}"
    tfrecord_dir = "{your_tfrecord_dir}"
    make_tfrecords(data_dir=data_dir, tfrecord_dir=tfrecord_dir)


if __name__ == "__main__":
    main()

将以上代码保存到文件make_tfrecord.py,运行命令:

注意:需要替换data_dirtfrecord_dir为个人目录。

python3 make_tfrecord.py

使用tree命令查看生成的TFRecord数据目录,输出内容如下:

.
├── test_0000.tfrecord
├── test_0001.tfrecord
├── train_0000.tfrecord
├── train_0001.tfrecord
├── train_0002.tfrecord
├── train_0003.tfrecord
├── train_0004.tfrecord
├── train_0005.tfrecord
├── train_0006.tfrecord
└── train_0007.tfrecord

0 directories, 10 files

4. 数据加载

有了3中的TFRecord数据集,下面来介绍如何在MindSpore中使用该数据集。

4.1 schema使用

4.1.1 不指定schema

首先来看看对于参数schema不指定,即采用默认值的情况下,能否正确读取数据。

代码如下:

import os

from mindspore.common import dtype as mstype
from mindspore.dataset import Schema
from mindspore.dataset import TFRecordDataset


def get_tfrecord_files(tfrecord_dir, file_suffix="tfrecord", is_train=True):
    if not os.path.exists(tfrecord_dir):
        raise ValueError("tfrecord directory: {} not exists!".format(tfrecord_dir))

    if is_train:
        file_prefix = "train"
    else:
        file_prefix = "test"

    data_sources = []
    for parent, _, filenames in os.walk(tfrecord_dir):
        for filename in filenames:
            if not filename.startswith(file_prefix):
                continue
            tmp_path = os.path.join(parent, filename)
            if tmp_path.endswith(file_suffix):
                data_sources.append(tmp_path)
    return data_sources


def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    dataset = TFRecordDataset(dataset_files=tfrecord_files, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break


def main():
    tfrecord_dir = "{your_tfrecord_dir}"
    tfrecord_json = "{your_tfrecord_json_file}"
    load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=None)


if __name__ == "__main__":
    main()

代码解读:

  • get_tfrecord_files -- 获取指定的TFRecord文件列表
  • load_tfrecord -- 数据集加载

将上述代码保存到文件load_tfrecord_dataset.py,运行如下命令:

python3 load_tfrecord_dataset.py

输出内容如下:

可以看出能正确解析出之前保存在TFRecord内的数据,数据类型和数据维度解析正确。

{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18, 
......
  135,  979,    1,   35,  166,  181,   90,  143])}

4.1.2 使用Schema对象

下面介绍,如何使用mindspore.dataset.Schema来指定读取模型策略。

修改load_tfrecord代码如下:

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    data_schema = Schema()
    data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
    data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

代码解读:

  • 这里使用了Schema对象,并且指定了列名,列的数据类型和数据维度。

保存并再次运行文件load_tfrecord_dataset.py,输出内容如下:

可以看出能正确解析出之前保存在TFRecord内的数据,数据类型和数据维度解析正确。

{'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18,  742,  330,  385,  999,  837,   56,  529, 1000,
.....
  135,  979,    1,   35,  166,  181,   90,  143]), 'class': Tensor(shape=[1], dtype=Int64, value= [0])}

4.1.3 使用JSON文件

下面介绍,如何使用JSON文件来指定读取模型策略。

新建tfrecord_sample.json文件,在文件内写入如下内容:

numRows -- 数据列数

columns -- 依次为每列的列名、数据类型、数据维数、数据维度。

{
  "datasetType": "TF",
  "numRows": 2,
  "columns": {
    "input": {
      "type": "int64",
      "rank": 1,
      "shape": [800]
    },
    "class" : {
      "type": "int64",
      "rank": 1,
      "shape": [1]
    }
  }
}

有了相应的JSON文件,下面来介绍如何使用该文件进行数据读取。

修改load_tfrecord代码如下:

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=tfrecord_json, shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

同时修改main部分代码如下:

load_tfrecord(tfrecord_dir=tfrecord_dir, tfrecord_json=tfrecord_json)

代码解读

  • 这里直接将schema参数指定为JSON的文件路径

保存并再次运行文件load_tfrecord_dataset.py,输出内容如下:

{'class': Tensor(shape=[1], dtype=Int64, value= [0]), 'input': Tensor(shape=[800], dtype=Int64, value= [1719,  636, 1063,   18, ......
  135,  979,    1,   35,  166,  181,   90,  143])}

4.2 columns_list使用

在某些场景下,我们可能只需要某(几)列的数据,而非全部数据,这时候就可以通过制定columns_list来进行数据加载。

下面我们只读取class列,来简单看看如何操作。

4.1.2基础上,修改load_tfrecord代码如下:

def load_tfrecord(tfrecord_dir, tfrecord_json=None):
    tfrecord_files = get_tfrecord_files(tfrecord_dir)
    # print("tfrecord files:\n{}".format("\n".join(tfrecord_files)), flush=True)

    data_schema = Schema()
    data_schema.add_column(name="input", de_type=mstype.int64, shape=[800])
    data_schema.add_column(name="class", de_type=mstype.int64, shape=[1])

    dataset = TFRecordDataset(dataset_files=tfrecord_files, schema=data_schema, columns_list=["class"], shuffle=False)

    data_iter = dataset.create_dict_iterator()
    for item in data_iter:
        print(item, flush=True)
        break

保存并再次运行文件load_tfrecord_dataset.py,输出内容如下:

可以看到只读取了我们指定的列,且数据加载正确。

{'class': Tensor(shape=[1], dtype=Int64, value= [0])}

5. 本文总结

本文介绍了在MindSpore中如何加载TFRecord数据集,并重点介绍了TFRecordDataset中的schemacolumns_list参数使用。

6. 本文参考

本文为原创文章,版权归作者所有,未经授权不得转载!

posted @ 2022-08-12 10:36  Skytier  阅读(129)  评论(0编辑  收藏  举报