Machine Learning Notes: sklearn Cross-Validation with KFold and StratifiedKFold

I. Cross-validation

Two cross-validation splitters are commonly used in machine learning: KFold and StratifiedKFold.

Import:

from sklearn.model_selection import KFold, StratifiedKFold
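  • StratifiedKFold: stratified splitting (the idea behind stratified random sampling). Each validation fold keeps the class proportions of the original sample, so the labels must be passed when splitting.
  • KFold: splits the data into consecutive folds in index order without looking at the labels. It does not shuffle by default (shuffle=False); set shuffle=True for a random split.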
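
To make the difference between the two splitters concrete, here is a minimal sketch on made-up data. The arrays and the helper name fold_class_counts are assumptions for illustration only; the labels are sorted on purpose so that the effect of ignoring them is easy to see.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical toy data: 12 samples, 8 of class 0 followed by 4 of class 1
X = np.arange(24).reshape(12, 2)
y = np.array([0] * 8 + [1] * 4)

def fold_class_counts(splitter):
    """Return [count of class 0, count of class 1] for every test fold."""
    return [np.bincount(y[test_idx], minlength=2).tolist()
            for _, test_idx in splitter.split(X, y)]

# KFold accepts y in split() but ignores it
kfold_counts = fold_class_counts(KFold(n_splits=4))
# StratifiedKFold uses y to keep the 2:1 class ratio in every fold
stratified_counts = fold_class_counts(StratifiedKFold(n_splits=4))

print("KFold:          ", kfold_counts)
print("StratifiedKFold:", stratified_counts)
```

With unshuffled KFold some test folds contain only one class, while every StratifiedKFold test fold keeps the 2:1 ratio of the full data set.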

II. KFold cross-validation

1. Syntax

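sklearn.model_selection.KFold(n_splits=3, # number of folds, at least 2 (the default is 5 since scikit-learn 0.22)
                             shuffle=False, # whether to shuffle the samples before splitting
                             random_state=None) # seed for the shuffle; only used when shuffle=True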
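
The effect of shuffle and random_state can be sketched as follows. The data and the helper name test_fold_indices are hypothetical; the point is only that the unshuffled split is sequential and that a fixed seed makes a shuffled split repeatable.

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(20).reshape(10, 2)  # 10 hypothetical samples

def test_fold_indices(**kwargs):
    """Collect the test indices of each fold as plain lists."""
    splitter = KFold(n_splits=5, **kwargs)
    return [test_idx.tolist() for _, test_idx in splitter.split(X)]

# Without shuffling, the folds are consecutive blocks of indices
sequential = test_fold_indices(shuffle=False)

# With shuffling, the same random_state reproduces the same folds
shuffled_a = test_fold_indices(shuffle=True, random_state=42)
shuffled_b = test_fold_indices(shuffle=True, random_state=42)

print("sequential:", sequential)
print("shuffled:  ", shuffled_a)
print("repeatable:", shuffled_a == shuffled_b)
```

Whether or not the data is shuffled, every sample lands in exactly one test fold.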

2. Hands-on

  • get_n_splits -- returns the number of folds
  • split -- yields the train/test index arrays for each fold
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.array([[1,2], [3,4], [5,6], [7,8]])
y = np.array([1,2,3,4])
kf = KFold(n_splits=2)
kf.get_n_splits() # 2
print(kf) # KFold(n_splits=2, random_state=None, shuffle=False)

# split() only needs the data here; the labels are not required
for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
'''
TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]
'''

III. StratifiedKFold cross-validation

1. Syntax

sklearn.model_selection.StratifiedKFold(n_splits=3, # same parameters as KFold
                                       shuffle=False,
                                       random_state=None)

2. Hands-on

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.array([[1,2], [3,4], [5,6], [7,8]])
y = np.array([1,0,0,1])
skf = StratifiedKFold(n_splits=2)
skf.get_n_splits() # 2
print(skf) # StratifiedKFold(n_splits=2, random_state=None, shuffle=False)

# pass both the data and the labels
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

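Note: the number of folds (n_splits) must not exceed the number of samples in each class (this is about the sample count per class, not the number of classes). If n_splits is larger than the sample count of every class, split raises the error below; in recent scikit-learn versions, if only the smallest class has fewer than n_splits samples, a warning is issued instead: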

ValueError: n_splits=2 cannot be greater than the number of members in each class.
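
A minimal sketch that triggers this error, assuming a made-up label array in which no class has as many members as the requested number of folds:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.zeros((6, 1))              # feature values do not matter for splitting
y = np.array([0, 0, 0, 0, 1, 1])  # class 0 has 4 members, class 1 has 2

# n_splits=2 does not exceed the size of any class, so this works
ok_folds = list(StratifiedKFold(n_splits=2).split(X, y))

# n_splits=5 exceeds the size of every class, so split() raises ValueError
try:
    list(StratifiedKFold(n_splits=5).split(X, y))
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(len(ok_folds), "folds created with n_splits=2")
print("n_splits=5 ->", error_message)
```

Note that split() is a generator, so the check only runs once the folds are actually consumed, here by list().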

Reference: sklearn.model_selection.KFold

Reference: sklearn.model_selection.StratifiedKFold

Reference: KFold and StratifiedKFold in Python sklearn

posted @ Hider1214