sklearn.model_selection.train_test_split

参考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

目的:将数组或矩阵分割为随机的训练和测试子集。

语法格式

sklearn.model_selection.train_test_split(*arraystest_size=Nonetrain_size=Nonerandom_state=Noneshuffle=Truestratify=None)

几个常用参数解释:

  • *arrays:允许输入列表、numpy数组、scipy-sparse矩阵或pandas数据框架。
  • test_size: float or int。接受0-1之间数值,表示test数据集的占比;接受整数值,表示训练集的绝对数量;默认值为None,表示总数据集减train数据集的size;如果train_size也为None,将设置为0.25
  • train_size: train数据集size。默认值为None,总数据集减test数据集的size
  • random_state: 随机种子,随便给个整数就行。目的是每次执行生成的数据集一致,方便重复
  • shuffle: 布尔值,表示是否在分割前对完整数据集顺序洗牌。默认值为True。如果shuffle=False,则stratify必须为None
  • stratify: 默认值为None。如果不是None,则以分层方式分割数据,并将其用作类标签。

返回包括train-test的list, length=2 * len(arrays)

代码示例

<一>

import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

d1 = [[3,"negative",2],[4,"negative",6],[11,"positive",0],[12,"positive",2],
[14,"positive",2],[3,"positive",5],[7,"negative",7],[7,"negative",6],
[1,"positive",2],[5,"negative",2]]
df1 = pd.DataFrame(d1, columns=["xuhao","result","value"])
print(df1,"\n")
# output
#    xuhao    result  value
# 0      3  negative      2
# 1      4  negative      6
# 2     11  positive      0
# 3     12  positive      2
# 4     14  positive      2
# 5      3  positive      5
# 6      7  negative      7
# 7      7  negative      6
# 8      1  positive      2
# 9      5  negative      2
idx = np.arange(0, len(df1))
idx_train, idx_test = train_test_split(idx, test_size=0.6) #idx也可直接替换为df1
print(idx_train)  
# [9 7 6 3]
print(idx_test,"\n")
# [2 4 8 5 1 0]
idx_train, idx_test = train_test_split(
	idx, stratify=df1['result'], test_size=0.6
	)
print(idx_train)
# [5 0 4 6]
print(idx_test)
# [3 7 1 2 8 9]

可以看到如果不加stratify选项,会存在抽取的train数据集对应的result列中三个样本都是negative;如果加上stratify=df1['result'],在提取train和test时,总是会按照result分类平均分配各样本。

<二>

将X,y分开输入进行数据分割

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(12).reshape((4, 3)), range(4)
print(X)
# [[ 0  1  2]
#  [ 3  4  5]
#  [ 6  7  8]
#  [ 9 10 11]]
print(list(y))
# [0, 1, 2, 3]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=111)
print(X_train)
# [[ 9 10 11]
#  [ 0  1  2]]
print(y_train)
# [3, 0]

 

 

 

posted @ 2023-03-18 15:09  yayagogogo  阅读(25)  评论(0编辑  收藏  举报