基于sklearn的数据集划分
https://blog.csdn.net/wade1203/article/details/91453804
from sklearn.model_selection import train_test_split
函数 train_test_split(data, label, test_size = 0.3, random_state =2020 )
参数解释:
data: 待划分样本数据(list形式)
label: 待划分样本数据的标签(list形式)
test_size: 测试数据占样本数据的比例(0到1之间的浮点数);若为整数,则表示测试样本的绝对数量。这里设置为0.3,即训练集与测试集的比例为0.7:0.3
random_state: 设置随机数种子。取固定整数(包括0)时,每次划分的结果都相同;如果不填(默认为None),则每次得到的划分结果都不一样
代码示例(使用时没有通过函数直接split 标签文件,这个可以自行修改的):
将image_10000 和 txt_10000文件夹按照7:3的比例分离出 train/img; train/gt 和 test/img; test/gt文件夹。注意:文件是被移动而不是复制,原始的image_10000和txt_10000文件夹在运行后不再保留
import os
import pathlib
import shutil
import sys
from glob import glob

from PIL import Image
from sklearn.model_selection import train_test_split


def split_dataset(root, test_size=0.3, random_state=2020):
    """Split ``root/image_10000`` and ``root/txt_10000`` into train/test folders.

    Images ending in ``jpg`` are listed from ``root/image_10000`` and split with
    ``train_test_split``. The training images and their same-named ``.txt``
    label files are moved to ``root/train/img`` and ``root/train/gt``. Whatever
    is left in the two source folders is then moved wholesale to
    ``root/test/img`` and ``root/test/gt``. Files are moved, not copied, so the
    two source folders no longer exist afterwards.

    Args:
        root: Dataset directory containing ``image_10000`` and ``txt_10000``.
        test_size: Passed through to ``train_test_split``.
        random_state: Passed through to ``train_test_split``.

    Raises:
        FileExistsError: If ``root/test/img`` or ``root/test/gt`` already exists.
        FileNotFoundError: If a training image has no matching label file.
    """
    src_img_dir = os.path.join(root, 'image_10000')
    src_txt_dir = os.path.join(root, 'txt_10000')
    train_img_dir = os.path.join(root, 'train', 'img')
    train_gt_dir = os.path.join(root, 'train', 'gt')
    test_dir = os.path.join(root, 'test')
    test_img_dir = os.path.join(test_dir, 'img')
    test_gt_dir = os.path.join(test_dir, 'gt')

    # Fail before anything is moved if a previous run already produced the
    # test folders; otherwise the error would only surface after the train
    # files had been moved.
    for target in (test_img_dir, test_gt_dir):
        if os.path.exists(target):
            raise FileExistsError(
                'destination already exists: {}'.format(target))

    # os.listdir returns entries in arbitrary order, so sort to make the
    # seeded split reproducible across machines and file systems.
    files = sorted(
        name for name in os.listdir(src_img_dir) if name.endswith('jpg'))
    train, test = train_test_split(
        files, test_size=test_size, random_state=random_state)
    print('train:{} images,test:{} images'.format(len(train), len(test)))

    # Pair every training image with its label and check that all labels are
    # present up front, so a missing label cannot leave a half-moved dataset.
    pairs = [(name, os.path.splitext(name)[0] + '.txt') for name in train]
    missing = [
        txt for _, txt in pairs
        if not os.path.isfile(os.path.join(src_txt_dir, txt))
    ]
    if missing:
        raise FileNotFoundError(
            '{} label file(s) missing in {}, e.g. {}'.format(
                len(missing), src_txt_dir, missing[0]))

    for directory in (train_img_dir, train_gt_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    for img_name, txt_name in pairs:
        shutil.move(os.path.join(src_img_dir, img_name),
                    os.path.join(train_img_dir, img_name))
        shutil.move(os.path.join(src_txt_dir, txt_name),
                    os.path.join(train_gt_dir, txt_name))
    print('move {} imgs totally'.format(len(pairs)))

    # Everything still in the source folders is the test split; the targets
    # were verified not to exist, so these moves rename the folders directly.
    shutil.move(src_img_dir, test_img_dir)
    shutil.move(src_txt_dir, test_gt_dir)


if __name__ == '__main__':
    split_dataset('./ICPR_text_train_part2_20180313')