Dive Into MindSpore – CSVDataset For Dataset Load
MindSpore Deep Dive Series – Dataset Loading with CSVDataset
Development environment for this article
- Ubuntu 20.04
- Python 3.8
- MindSpore 1.7.0
Contents
- The API first
- Data preparation
- Two pitfalls
- A working example
- Summary
- Possible improvements
- References
1. The API First
As usual, let's start with the official documentation.
Parameter walkthrough:
- dataset_files – path(s) to the dataset file(s); either a single file or a list of files
- field_delim – field delimiter, defaults to ","
- column_defaults – a treacherous parameter; we will come back to it below
- column_names – column names, used as the keys of the loaded data columns
- num_parallel_workers – self-explanatory
- shuffle – whether to shuffle the data; three options: [False, Shuffle.GLOBAL, Shuffle.FILES]
  - Shuffle.GLOBAL – shuffle both the files and the samples within them (default)
  - Shuffle.FILES – shuffle the files only
- num_shards – self-explanatory
- shard_id – self-explanatory
- cache – self-explanatory
2. Data Preparation
2.1 Downloading the Data
Instructions:
Use the following commands to download iris.data and iris.names into the target directory:
mkdir iris && cd iris
wget -c https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
wget -c https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names
**Note:** if your system does not provide wget, you can download the files in a browser instead; the URLs are the ones listed above.
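If neither wget nor a browser is convenient, the two files can also be fetched with Python's standard library. This is a minimal sketch (the `download_iris` helper is my own, not part of MindSpore); it skips files that already exist, roughly mimicking wget -c:

```python
from pathlib import Path
from urllib.request import urlretrieve

BASE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris"


def download_iris(target_dir="iris", names=("iris.data", "iris.names")):
    """Download the iris files into target_dir, skipping files that already exist."""
    out = Path(target_dir)
    out.mkdir(exist_ok=True)
    paths = []
    for name in names:
        dest = out / name
        if not dest.exists():  # rough equivalent of wget -c's skip-if-present behavior
            urlretrieve("{}/{}".format(BASE_URL, name), dest)
        paths.append(dest)
    return paths
```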
2.2 About the Dataset
Iris, also known as the iris flower dataset, is a classic multivariate dataset. It contains 150 samples in 3 classes of 50 samples each, where every sample has 4 attributes. From the four attributes sepal length, sepal width, petal length, and petal width, one can predict which of the three species (Setosa, Versicolour, Virginica) a given iris flower belongs to.
For a more detailed introduction, see the official description:
5. Number of Instances: 150 (50 in each of three classes)
6. Number of Attributes: 4 numeric, predictive attributes and the class
7. Attribute Information:
1. sepal length in cm
2. sepal width in cm
3. petal length in cm
4. petal width in cm
5. class:
-- Iris Setosa
-- Iris Versicolour
-- Iris Virginica
8. Missing Attribute Values: None
Summary Statistics:
Min Max Mean SD Class Correlation
sepal length: 4.3 7.9 5.84 0.83 0.7826
sepal width: 2.0 4.4 3.05 0.43 -0.4194
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
9. Class Distribution: 33.3% for each of 3 classes.
2.3 Splitting the Data
Here we do an initial split of the data into a training set and a test set, with a 4:1 ratio.
The preprocessing code is as follows:
from random import shuffle


def preprocess_iris_data(iris_data_file, train_file, test_file, header=True):
    cls_0 = "Iris-setosa"
    cls_1 = "Iris-versicolor"
    cls_2 = "Iris-virginica"

    cls_0_samples = []
    cls_1_samples = []
    cls_2_samples = []

    with open(iris_data_file, "r", encoding="UTF8") as fp:
        lines = fp.readlines()

    for line in lines:
        line = line.strip()
        if not line:
            continue
        if cls_0 in line:
            cls_0_samples.append(line)
            continue
        if cls_1 in line:
            cls_1_samples.append(line)
            continue
        if cls_2 in line:
            cls_2_samples.append(line)

    shuffle(cls_0_samples)
    shuffle(cls_1_samples)
    shuffle(cls_2_samples)

    print("number of class 0: {}".format(len(cls_0_samples)), flush=True)
    print("number of class 1: {}".format(len(cls_1_samples)), flush=True)
    print("number of class 2: {}".format(len(cls_2_samples)), flush=True)

    # 40 samples per class go to training, the remaining 10 to testing (4:1)
    train_samples = cls_0_samples[:40] + cls_1_samples[:40] + cls_2_samples[:40]
    test_samples = cls_0_samples[40:] + cls_1_samples[40:] + cls_2_samples[40:]

    header_content = "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Classes"

    with open(train_file, "w", encoding="UTF8") as fp:
        if header:
            fp.write("{}\n".format(header_content))
        for sample in train_samples:
            fp.write("{}\n".format(sample))

    with open(test_file, "w", encoding="UTF8") as fp:
        if header:
            fp.write("{}\n".format(header_content))
        for sample in test_samples:
            fp.write("{}\n".format(sample))


def main():
    iris_data_file = "{your_path}/iris/iris.data"
    iris_train_file = "{your_path}/iris/iris_train.csv"
    iris_test_file = "{your_path}/iris/iris_test.csv"
    preprocess_iris_data(iris_data_file, iris_train_file, iris_test_file)


if __name__ == "__main__":
    main()
Save the code above to preprocess.py and run it with the following command:
Remember to adjust the data file paths.
python3 preprocess.py
The output is:
number of class 0: 50
number of class 1: 50
number of class 2: 50
At the same time, iris_train.csv and iris_test.csv are generated in the target directory, whose contents now look like this:
.
├── iris.data
├── iris.names
├── iris_test.csv
└── iris_train.csv
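As a quick sanity check, the split can be verified by counting rows: with header=True, each file carries one header line plus 120 training or 30 test samples. A small sketch (the `count_csv_rows` helper is my own, not a MindSpore API):

```python
def count_csv_rows(csv_file, has_header=True):
    """Count the data rows in a CSV file, optionally skipping the header line."""
    with open(csv_file, "r", encoding="UTF8") as fp:
        rows = [line for line in fp if line.strip()]
    return len(rows) - 1 if has_header else len(rows)


# Expected for the split above: 120 training rows and 30 test rows
# (40/10 per class, 3 classes):
# count_csv_rows("{your_path}/iris/iris_train.csv")  -> 120
# count_csv_rows("{your_path}/iris/iris_test.csv")   -> 30
```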
3. Two Pitfalls
Let's take a first look at CSVDataset through a couple of **"wrong"** (hence the quotes) usages.
3.1 What Is column_defaults Anyway
First, a simple load attempt. The code is as follows:
To make the results easy to reproduce, shuffle is set to False here.
from mindspore.dataset import CSVDataset


def dataset_load(data_files):
    column_defaults = [float, float, float, float, str]
    column_names = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Classes"]

    dataset = CSVDataset(
        dataset_files=data_files,
        field_delim=",",
        column_defaults=column_defaults,
        column_names=column_names,
        num_samples=None,
        shuffle=False)

    data_iter = dataset.create_dict_iterator()
    item = None
    for data in data_iter:
        item = data
        break

    print("====== sample ======\n{}".format(item), flush=True)


def main():
    iris_train_file = "{your_path}/iris/iris_train.csv"
    dataset_load(data_files=iris_train_file)


if __name__ == "__main__":
    main()
Save the code above to load.py and run:
Remember to adjust the data file path.
python3 load.py
What? An error. Let's look at the message:
Traceback (most recent call last):
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in <module>
main()
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
dataset_load(data_files=iris_train_file)
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 75, in dataset_load
dataset = CSVDataset(
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/validators.py", line 1634, in new_method
raise TypeError("column type in column_defaults is invalid.")
TypeError: column type in column_defaults is invalid.
Let's look at the source that raises the error, line 1634 of mindspore/dataset/engine/validators.py:
# check column_defaults
column_defaults = param_dict.get('column_defaults')
if column_defaults is not None:
    if not isinstance(column_defaults, list):
        raise TypeError("column_defaults should be type of list.")
    for item in column_defaults:
        if not isinstance(item, (str, int, float)):
            raise TypeError("column type in column_defaults is invalid.")
3.1.1 Error Analysis
For a deeper analysis of the column_defaults parameter, see section 6.1.
Remember the official parameter description? No worries if not; here it is again:
- column_defaults (list, optional) – specifies the data type of each column. Valid types are float, int, or string. Default: None. If this parameter is not specified, every column is treated as string.
Clearly, the official description speaks of data types, but the code in mindspore/dataset/engine/validators.py actually checks for instances of those types. With that cleared up, change the line:
column_defaults = [float, float, float, float, str]
to the following (the numeric values are taken from the iris.names file; see that file for details):
column_defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]
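The distinction is easy to verify in plain Python: the validator's isinstance check accepts value instances of str/int/float, not the type objects themselves.

```python
# The validator calls isinstance(item, (str, int, float)) on every entry,
# so it accepts value instances, not type objects.
print(isinstance(float, (str, int, float)))      # False – the type object is rejected
print(isinstance(5.84, (str, int, float)))       # True – a float instance passes
print(isinstance("Classes", (str, int, float)))  # True – a str instance passes
```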
Run the code again, and it fails again:
WARNING: Logging before InitGoogleLogging() is written to STDERR
[ERROR] MD(13306,0x70000269b000,Python):2022-06-14-16:51:59.681.109 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:217] InterruptMaster] Task is terminated with err msg(more detail in info level log):Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
Traceback (most recent call last):
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 107, in <module>
main()
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 103, in main
dataset_load(data_files=iris_train_file)
File "/Users/kaierlong/Documents/Codes/OpenI/Dive_Into_MindSpore/code/chapter_01/csv_dataset.py", line 90, in dataset_load
for data in data_iter:
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 147, in __next__
data = self._get_next()
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 211, in _get_next
raise err
File "/Users/kaierlong/Documents/PyEnv/env_ms_1.7.0/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 204, in _get_next
return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()}
RuntimeError: Unexpected error. Invalid csv, csv file: /Users/kaierlong/Downloads/iris/iris_train.csv parse failed at line 1, type does not match.
Line of code : 506
File : /Users/jenkins/agent-working-dir/workspace/Compile_CPU_X86_MacOS_PY39/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
Fine; we will analyze this error in section 3.2.
3.2 To header or Not to header
In 3.1, by analyzing the source behind the error we established the correct usage of column_defaults, yet one error still remains.
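The failure can be reproduced without MindSpore. Below is a hypothetical sketch of the per-field type check (the `try_parse` helper and its names are my own, not a MindSpore API): when column_names is passed explicitly, the header row of the CSV is read as an ordinary data row, and a string such as "Sepal.Length" cannot satisfy a float default.

```python
def try_parse(row_fields, defaults):
    """Hypothetical sketch of the per-field type check the CSV parser performs."""
    parsed = []
    for field, default in zip(row_fields, defaults):
        if isinstance(default, float):
            parsed.append(float(field))  # raises ValueError on a type mismatch
        else:
            parsed.append(field)
    return parsed


defaults = [5.84, 3.05, 3.76, 1.20, "Classes"]
header_row = "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Classes".split(",")
data_row = "5.1,3.5,1.4,0.2,Iris-setosa".split(",")

try_parse(data_row, defaults)       # a real data row parses fine
try:
    try_parse(header_row, defaults)  # the header row is treated as data
except ValueError as err:
    print("line 1 fails:", err)
```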
3.2.1 Error Analysis
The error message points to line 506 of mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc. The relevant source is:
Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  CsvParser csv_parser(worker_id, jagged_rows_connector_.get(), field_delim_, column_default_list_, file);
  RETURN_IF_NOT_OK(csv_parser.InitCsvParser());
  csv_parser.SetStartOffset(start_offset);
  csv_parser.SetEndOffset(end_offset);

  auto realpath = FileUtils::GetRealPath(file.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
  }

  std::ifstream ifs;
  ifs.open(realpath.value(), std::ifstream::in);
  if (!ifs.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + file + ", the file is damaged or permission denied.");
  }
  if (column_name_list_.empty()) {
    std::string tmp;
    getline(ifs, tmp);
  }
  csv_parser.Reset();
  try {
    while (ifs.good()) {
      // when ifstream reaches the end of file, the function get() return std::char_traits<char>::eof()
      // which is a 32-bit -1, it's not equal to the 8-bit -1 on Euler OS. So instead of char, we use
      // int to receive its return value.
      int chr = ifs.get();
      int err = csv_parser.ProcessMessage(chr);
      if (err != 0) {
        // if error code is -2, the returned error is interrupted
        if (err == -2) return Status(kMDInterrupted);
        RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse csv file: " + file + " at line " +
                                 std::to_string(csv_parser.GetTotalRows() + 1) +
                                 ". Error message: " + csv_parser.GetErrorMessage());
      }
    }
  } catch (std::invalid_argument &ia) {
    std::string err_row = std::to_string(csv_parser.GetTotalRows() + 1);
    RETURN_STATUS_UNEXPECTED("Invalid csv, csv file: " + file + " parse failed at line " + err_row +