基于sklearn的数据集划分

https://blog.csdn.net/wade1203/article/details/91453804

 

from sklearn.model_selection import train_test_split

函数 train_test_split(data, label, test_size = 0.3, random_state =2020 )

参数解释:

   data: 待划分样本数据(list形式)

   label: 待划分样本数据的标签(list形式)

  test_size: 测试数据占样本数据的比例,若整数则是测试数据量,这里设置为0.3,则训练测试比例:0.7:0.3

  random_state: 设置随机数种子,保证每次划分结果都相同。只要传入固定的整数(包括0),每次得到的划分都一样;如果不填(默认为None),则每次得到数据都不一样

代码示例(这里只对图片文件名列表调用 train_test_split,标签文件按同名 .txt 一一对应移动,没有通过函数直接 split 标签文件,这个可以自行修改的):

将 image_10000 和 txt_10000 文件夹按照7:3的比例分离出 train/img、train/gt 和 test/img、test/gt 文件夹;原始文件夹不会保留(训练部分被移走后,剩余部分连同文件夹一起被移动并重命名为 test/img 和 test/gt)

from sklearn.model_selection import train_test_split
import os
import sys
import pathlib
from glob import glob
from PIL import Image
import shutil

def move_labeled_images(names, img_src, gt_src, img_dst, gt_dst):
    """Move each image in *names* together with its same-stem ``.txt`` label.

    Every label is checked before anything is moved, so a missing label
    cannot leave the dataset half-split.

    Args:
        names: image file names (not paths) located in *img_src*.
        img_src: directory currently holding the images.
        gt_src: directory currently holding the ``<stem>.txt`` label files.
        img_dst: existing directory that receives the images.
        gt_dst: existing directory that receives the label files.

    Returns:
        The number of image/label pairs moved.

    Raises:
        FileNotFoundError: if any image in *names* has no label file in
            *gt_src*. Nothing has been moved when this is raised.
    """
    pairs = [(name, os.path.splitext(name)[0] + '.txt') for name in names]

    missing = [txt for _, txt in pairs
               if not os.path.isfile(os.path.join(gt_src, txt))]
    if missing:
        raise FileNotFoundError(
            '{} label file(s) missing in {}, e.g. {}'.format(
                len(missing), gt_src, missing[0]))

    for img_name, txt_name in pairs:
        shutil.move(os.path.join(img_src, img_name),
                    os.path.join(img_dst, img_name))
        shutil.move(os.path.join(gt_src, txt_name),
                    os.path.join(gt_dst, txt_name))
    return len(pairs)


if __name__ == '__main__':
    # Retained from the original script: append this file's absolute path
    # and its parent directory to sys.path.
    script_path = pathlib.Path(os.path.abspath(__file__))
    sys.path.append(str(script_path))
    sys.path.append(str(script_path.parent))

    root = './ICPR_text_train_part2_20180313'
    pth_img = os.path.join(root, 'image_10000')
    pth_txt = os.path.join(root, 'txt_10000')
    train_root = os.path.join(root, 'train')
    test_root = os.path.join(root, 'test')
    img_train_pth = os.path.join(train_root, 'img')
    gt_train_pth = os.path.join(train_root, 'gt')
    img_test_pth = os.path.join(test_root, 'img')
    gt_test_pth = os.path.join(test_root, 'gt')

    # The source folders are renamed to test/img and test/gt at the end;
    # refuse to start if those targets already exist rather than failing
    # after the training files have been moved.
    for target in (img_test_pth, gt_test_pth):
        if os.path.exists(target):
            raise FileExistsError(
                'destination already exists: {}'.format(target))

    # List and split first so a missing source folder fails before any
    # directory is created.
    files = [img for img in os.listdir(pth_img) if img.endswith('jpg')]
    train, test = train_test_split(files, test_size=0.3, random_state=2020)
    # train: img gt    test: img gt
    print('train:{} images,test:{} images'.format(len(train), len(test)))

    os.makedirs(test_root, exist_ok=True)
    os.makedirs(img_train_pth, exist_ok=True)
    os.makedirs(gt_train_pth, exist_ok=True)

    moved = move_labeled_images(
        train, pth_img, pth_txt, img_train_pth, gt_train_pth)
    print('move {} imgs totally'.format(moved))

    # Everything still left in the source folders is the test split, so the
    # folders themselves become test/img and test/gt.
    shutil.move(pth_img, img_test_pth)
    shutil.move(pth_txt, gt_test_pth)

 

posted @ 2021-08-16 14:56  猪大大BiuBiuBiu  阅读(114)  评论(0)  编辑  收藏  举报