数据集划分

 1 import random
 2 import os
 3 
 4 random.seed(0)
 5 
 6 source_path = '/data1/zjh/FFHQ/1024'
 7 source_list = os.listdir(source_path)
 8 divided_path = '/data1/zjh/FFHQ_divided'
 9 if not os.path.exists(os.path.join(divided_path, 'train')):
10     os.makedirs(os.path.join(divided_path, 'train'))
11 if not os.path.exists(os.path.join(divided_path, 'val')):
12     os.makedirs(os.path.join(divided_path, 'val'))
13 
14 eval_index = random.sample(source_list, k=int(70000 * 3//10))
15 for index, source_list_name in enumerate(source_list):
16     print(index)
17     # eval_index 中保存验证集val的图像名称
18     if source_list_name in eval_index:
19         os.system("cp %s %s" % (os.path.join(source_path, source_list_name), os.path.join(divided_path, 'val')))
20     else:
21         os.system("cp %s %s" % (os.path.join(source_path, source_list_name), os.path.join(divided_path, 'train')))

 

posted @ 2022-12-12 11:32  Jiahui_Zhan  阅读(41)  评论(0)  编辑  收藏  举报