sklearn.model_selection.train_test_split
参考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
目的:将数组或矩阵分割为随机的训练和测试子集。
语法格式
sklearn.model_selection.train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)
几个常用参数解释:
- *arrays:允许输入列表、numpy数组、scipy-sparse矩阵或pandas数据框架。
- test_size: float or int。接受0-1之间数值,表示test数据集的占比;接受整数值,表示训练集的绝对数量;默认值为None,表示总数据集减train数据集的size;如果train_size也为None,将设置为0.25
- train_size: train数据集size。默认值为None,总数据集减test数据集的size
- random_state: 随机种子,随便给个整数就行。目的是每次执行生成的数据集一致,方便重复
- shuffle: 布尔值,表示是否在分割前对完整数据集顺序洗牌。默认值为True。如果shuffle=False,则stratify必须为None
- stratify: 默认值为None。如果不是None,则以分层方式分割数据,并将其用作类标签。
返回包括train-test的list, length=2 * len(arrays)
代码示例
<一>
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
d1 = [[3,"negative",2],[4,"negative",6],[11,"positive",0],[12,"positive",2],
[14,"positive",2],[3,"positive",5],[7,"negative",7],[7,"negative",6],
[1,"positive",2],[5,"negative",2]]
df1 = pd.DataFrame(d1, columns=["xuhao","result","value"])
print(df1,"\n")
# output
# xuhao result value
# 0 3 negative 2
# 1 4 negative 6
# 2 11 positive 0
# 3 12 positive 2
# 4 14 positive 2
# 5 3 positive 5
# 6 7 negative 7
# 7 7 negative 6
# 8 1 positive 2
# 9 5 negative 2
idx = np.arange(0, len(df1))
idx_train, idx_test = train_test_split(idx, test_size=0.6) #idx也可直接替换为df1
print(idx_train)
# [9 7 6 3]
print(idx_test,"\n")
# [2 4 8 5 1 0]
idx_train, idx_test = train_test_split(
idx, stratify=df1['result'], test_size=0.6
)
print(idx_train)
# [5 0 4 6]
print(idx_test)
# [3 7 1 2 8 9]
可以看到如果不加stratify选项,会存在抽取的train数据集对应的result列中三个样本都是negative;如果加上stratify=df1['result'],在提取train和test时,总是会按照result分类平均分配各样本。
<二>
将X,y分开输入进行数据分割
import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(12).reshape((4, 3)), range(4)
print(X)
# [[ 0 1 2]
# [ 3 4 5]
# [ 6 7 8]
# [ 9 10 11]]
print(list(y))
# [0, 1, 2, 3]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=111)
print(X_train)
# [[ 9 10 11]
# [ 0 1 2]]
print(y_train)
# [3, 0]