数据集划分
# Split an image folder (FFHQ by default) into train/val subsets by copying
# files: a fixed-seed random sample (30% by default) goes to <dst>/val and the
# remaining files go to <dst>/train.
import os
import random
import shutil

# Folder containing the full image set (one file per image).
source_path = '/data1/zjh/FFHQ/1024'
# Output root; 'train' and 'val' subfolders are created under it.
divided_path = '/data1/zjh/FFHQ_divided'


def split_dataset(src_dir, dst_dir, val_percent=30, seed=0):
    """Copy the regular files of *src_dir* into ``dst_dir/train`` and ``dst_dir/val``.

    Args:
        src_dir: Directory whose regular files are split.
        dst_dir: Output root; ``train`` and ``val`` are created if missing.
        val_percent: Integer percentage of files placed in ``val``, rounded
            down (``len(files) * val_percent // 100``), matching the original
            ``n * 3 // 10`` for the default of 30.
        seed: Seed for the sampling RNG, so the split is reproducible.

    Returns:
        A ``(train_names, val_names)`` tuple of file-name lists, each in
        sorted order.

    Raises:
        ValueError: If *val_percent* is negative or greater than 100
            (raised by ``random.sample``).
        OSError: If a file cannot be copied.
    """
    train_dir = os.path.join(dst_dir, 'train')
    val_dir = os.path.join(dst_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # os.listdir order is filesystem-dependent; sort so the same seed always
    # yields the same split. Only regular files are candidates (plain `cp`
    # could not copy sub-directories anyway).
    names = sorted(
        name for name in os.listdir(src_dir)
        if os.path.isfile(os.path.join(src_dir, name))
    )
    # Size the validation set from the actual file count rather than a
    # hard-coded 70000, which made random.sample fail on smaller folders.
    val_count = len(names) * val_percent // 100
    # A set gives O(1) membership tests; the original list made the loop O(n^2).
    val_names = set(random.Random(seed).sample(names, k=val_count))

    train_list, val_list = [], []
    for index, name in enumerate(names):
        print(index)  # progress indicator, as in the original script
        if name in val_names:
            target_dir, bucket = val_dir, val_list
        else:
            target_dir, bucket = train_dir, train_list
        # shutil.copy bypasses the shell: names with spaces or shell
        # metacharacters are handled correctly, and failures raise instead of
        # being silently ignored like os.system's exit status.
        shutil.copy(os.path.join(src_dir, name), target_dir)
        bucket.append(name)
    return train_list, val_list


if __name__ == '__main__':
    split_dataset(source_path, divided_path)