python 拆分多类别数据集

原数据集形式:收集的数据来自两个来源文件夹(folder),每个文件夹下的数据分为三类(class1–class3)。

 

 希望得到的数据集形式: 将数据集拆分为train和test两部分,每部分都包含所有类别。

 

完整代码如下(已包含注释,自测可用;参考文章:《数据集划分、label生成及按label将图片分类到不同文件夹》):

 1 import os
 2 # import cv2
 3 import random
 4 import sys
 5 from random import randint
 6 import shutil
 7 
 8 def fileExist(path1):
 9     if os.path.exists(path1):
10         return
11     else:
12         try:
13             os.mkdir(path1)  # 创建单层文件夹
14         except Exception as e:
15             os.makedirs(path1)  # 创建多层文件夹
16 
17 
def split_dataset(root_path, new_path, ratio=0.7, seed=None):
    """Copy each class folder under ``root_path`` into train/test splits.

    For every sub-directory ``<cls>`` of ``root_path`` the files it contains
    are randomly divided; ``int(n * ratio)`` of them are copied to
    ``new_path/train/<cls>`` and the rest to ``new_path/test/<cls>``.
    The source files are copied, never moved.

    Args:
        root_path: Directory whose sub-directories are the class folders,
            e.g. ``folder1/[class1, class2, ...]``.
        new_path: Output directory, e.g. ``dataset1/folder1``. The
            ``train``/``test`` class directories are created if missing.
        ratio: Fraction of each class copied to the train split; must lie
            within ``[0, 1]``. Defaults to 0.7.
        seed: Optional seed. When given, the split is reproducible across
            runs; when ``None`` the module-level ``random`` state is used.

    Returns:
        None.

    Raises:
        ValueError: If ``ratio`` is outside ``[0, 1]``.
        OSError: If a directory cannot be read or created, or a copy fails.
    """
    if not 0 <= ratio <= 1:
        raise ValueError("ratio must be within [0, 1], got %r" % (ratio,))

    # A private generator keeps a seeded split from disturbing global state.
    rng = random if seed is None else random.Random(seed)

    for folder in sorted(os.listdir(root_path)):  # class1, class2, ...
        origin_path = os.path.join(root_path, folder)
        if not os.path.isdir(origin_path):
            # Stray files next to the class folders are not classes.
            continue

        train_path = os.path.join(new_path, "train", folder)
        test_path = os.path.join(new_path, "test", folder)
        # Create the targets here so the function does not depend on the
        # caller having prepared the directory tree.
        os.makedirs(train_path, exist_ok=True)
        os.makedirs(test_path, exist_ok=True)

        # os.listdir order is arbitrary; sorting makes a seeded split stable.
        # Only regular files are split: shutil.copy cannot copy directories.
        img_list = sorted(
            name for name in os.listdir(origin_path)
            if os.path.isfile(os.path.join(origin_path, name))
        )

        train_num = int(len(img_list) * ratio)
        train_sample = rng.sample(img_list, train_num)
        chosen = set(train_sample)
        test_sample = [name for name in img_list if name not in chosen]

        for names, dst_dir in ((train_sample, train_path),
                               (test_sample, test_path)):
            for name in names:
                shutil.copy(os.path.join(origin_path, name),
                            os.path.join(dst_dir, name))
39 
40 
if __name__ == '__main__':
    # Expected layout: dataset/<domain>/<class>/<files>; the split copy is
    # written to dataset1/<domain>/{train,test}/<class>/<files>.
    root_path = "dataset"
    new_path = "dataset1"

    # Only directories are data sources; stray files directly under
    # root_path would otherwise make os.listdir() fail below.
    domains = [
        domain for domain in os.listdir(root_path)
        if os.path.isdir(os.path.join(root_path, domain))
    ]

    # Create the output directory tree.
    for domain in domains:
        domain_path = os.path.join(root_path, domain)
        domain_new_path = os.path.join(new_path, domain)
        for folder in os.listdir(domain_path):  # class1, class2, ...
            if not os.path.isdir(os.path.join(domain_path, folder)):
                continue
            fileExist(os.path.join(domain_new_path, "train", folder))
            fileExist(os.path.join(domain_new_path, "test", folder))

    # Split each domain into the new location.
    for domain in domains:
        domain_path = os.path.join(root_path, domain)
        domain_new_path = os.path.join(new_path, domain)
        split_dataset(domain_path, domain_new_path)

 

posted @ 2021-08-10 13:00  achived  阅读(1089)  评论(0)  编辑  收藏  举报