MindSpore:MindSpore易点通·精讲系列--数据集加载之MindDataset
Dive Into MindSpore – MindDataset For Dataset Load
MindSpore易点通·精讲系列–数据集加载之MindDataset
本文开发环境
- Ubuntu 20.04
- Python 3.8
- MindSpore 1.7.0
本文内容摘要
- 背景介绍
- 先看文档
- 数据生成
- 数据加载
- 问题解答
- 本文总结
- 本文参考
1. 背景介绍
在前面的文章中,我们介绍了ImageFolderDataset
、CSVDataset
及TFRecordDataset
三个数据集加载API。
本文为数据集加载部分的最后一篇文章(当然,如果后续读者有需要,再考虑补充其他API
精讲),我们将介绍MindSpore
中官方数据格式MindRecord
加载所涉及的API
的MindDataset
。
一个完整的机器学习工作流包括数据集读取(可能包含数据处理)、模型定义、模型训练、模型评估。如何在工作流中更好的读取数据,是各个深度学习框架需要解决的一个重要问题。为此,TensorFlow
推出了TFRecord
数据格式,而MindSpore
给出的解决方案就是MindRecord
。在正式开始本文的讲解之前,先来看看MindRecord
数据格式的特点:
- 实现数据统一存储、访问,使得训练时数据读取更加简便。
- 数据聚合存储、高效读取,使得训练时数据方便管理和移动。
- 高效的数据编解码操作,使得用户可以对数据操作无感知。
- 可以灵活控制数据切分的分区大小,实现分布式数据处理。
2. 先看文档
老传统,先看官方文档。
下面对官方文档中的参数,做简单解读:
dataset_files
– 类型为字符串或者列表。如果为字符串则按照匹配规则自动寻找并加载相应前缀的MindRecord
文件;如果为列表,则读取列表内的MindRecord
文件,即列表内要为具体的文件名。columns_list
– 指定从MindRecord
数据文件中读取的数据字段,或者说数据列。默认值为None
,即读取全部字段或数据列。- 其他参数参见之前文章中的相关解读。
3. 数据生成
本文使用的是
THUCNews
数据集,如果需要将该数据集用于商业用途,请联系数据集作者。
在上面API
解读中,我们讲到MindDasetset
读取的是MindRecord
文件,下面就来介绍一下如何生成MindRecord
数据文件。
MindRecord
数据文件生成可以简单包含以下几个部分(非顺序):
- 读取及处理原始数据
- 声明
MindRecord
文件格式 - 定义
MindRecord
数据字段 - 添加
MindRecord
索引字段 - 写入
MindRecord
数据内容
3.1 生成代码
下面我们基于THUCNews
数据集,来生成MindRecord
数据。
3.1.1 代码部分
import codecs
import os
import re
import numpy as np
from collections import Counter
from mindspore.mindrecord import FileWriter
def get_txt_files(data_dir):
cls_txt_dict = {}
txt_file_list = []
# get files list and class files list.
sub_data_name_list = next(os.walk(data_dir))[1]
sub_data_name_list = sorted(sub_data_name_list)
for sub_data_name in sub_data_name_list:
sub_data_dir = os.path.join(data_dir, sub_data_name)
data_name_list = next(os.walk(sub_data_dir))[2]
data_file_list = [os.path.join(sub_data_dir, data_name) for data_name in data_name_list]
cls_txt_dict[sub_data_name] = data_file_list
txt_file_list.extend(data_file_list)
num_data_files = len(data_file_list)
print("{}: {}".format(sub_data_name, num_data_files), flush=True)
num_txt_files = len(txt_file_list)
print("total: {}".format(num_txt_files), flush=True)
return cls_txt_dict, txt_file_list
def get_txt_data(txt_file):
with codecs.open(txt_file, "r", "UTF8") as fp:
txt_content = fp.read()
txt_data = re.sub("\s+", " ", txt_content)
return txt_data
def build_vocab(txt_file_list, vocab_size=7000):
counter = Counter()
for txt_file in txt_file_list:
txt_data = get_txt_data(txt_file)
counter.update(txt_data)
num_vocab = len(counter)
if num_vocab < vocab_size - 1:
real_vocab_size = num_vocab + 2
else:
real_vocab_size = vocab_size
# pad_id is 0, unk_id is 1
vocab_dict = {word_freq[0]: ix + 1 for ix, word_freq in enumerate(counter.most_common(real_vocab_size - 2))}
print("real vocab size: {}".format(real_vocab_size), flush=True)
print("vocab dict:\n{}".format(vocab_dict), flush=True)
return vocab_dict
def make_mindrecord_files(
data_dir, mindrecord_dir, vocab_size=7000, min_seq_length=10, max_seq_length=800,
num_train_shard=16, num_test_shard=4):
# get txt files
cls_txt_dict, txt_file_list = get_txt_files(data_dir=data_dir)
# map word to id
vocab_dict = build_vocab(txt_file_list=txt_file_list, vocab_size