Explaining KFold.split() and similar methods

 

class sklearn.model_selection.KFold(n_splits=5, *, shuffle=False, random_state=None)

  

>>> import numpy as np
>>> from sklearn.model_selection import KFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([1, 2, 3, 4])
>>> kf = KFold(n_splits=2)
>>> kf.get_n_splits(X)
2
>>> print(kf)
KFold(n_splits=2, random_state=None, shuffle=False)
>>> for train_index, test_index in kf.split(X):
...     print("TRAIN:", train_index, "TEST:", test_index)
...     X_train, X_test = X[train_index], X[test_index]
...     y_train, y_test = y[train_index], y[test_index]
TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]

  

Methods

get_n_splits([X, y, groups])

Returns the number of splitting iterations in the cross-validator

split(X[, y, groups])

Generate indices to split data into training and test set.
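A minimal sketch of how these two methods behave, assuming a made-up array of 6 samples (the data and the variable names n_folds and folds are illustrative, not from the scikit-learn docs). It shows that get_n_splits returns a plain integer, while split is a generator that yields one (train, test) pair of index arrays per fold:

```python
import numpy as np
from sklearn.model_selection import KFold

# Made-up data: 6 samples with 2 features each.
X = np.arange(12).reshape(6, 2)

kf = KFold(n_splits=3)  # shuffle=False, so folds are contiguous blocks

n_folds = kf.get_n_splits(X)  # plain int: the number of folds
folds = list(kf.split(X))     # split() is a generator; materialize it once

for train_idx, test_idx in folds:
    # Both values are integer index arrays, not the data rows themselves.
    print("train:", train_idx, "test:", test_idx)
```

Because shuffling is off here, the test folds come out in order as [0 1], [2 3], [4 5].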

 

# skf is not defined in this snippet; it is presumably a StratifiedKFold instance
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_jpg, train_jpg)):
    ...
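The loop above can be sketched as a runnable example, under the assumption that skf is a StratifiedKFold. The file names and the train_label array below are hypothetical stand-ins; a stratified splitter needs class labels as its second argument to balance the folds:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical stand-ins: the file names and labels are made up.
train_jpg = np.array([f"img_{i}.jpg" for i in range(8)])
train_label = np.array([0, 0, 0, 0, 1, 1, 1, 1])

skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)

val_counts = []
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_jpg, train_label)):
    # Index the arrays with the returned indices to get this fold's data.
    val_files = train_jpg[val_idx]
    val_labels = train_label[val_idx]
    # Count how many samples of each class landed in the validation fold.
    val_counts.append(np.bincount(val_labels).tolist())
    print(fold_idx, val_files, val_labels)
```

enumerate only adds the fold number; the (train_idx, val_idx) pair is the same kind of index pair that KFold.split yields.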

  

 

Enough talk; let an example do the explaining:

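from sklearn.model_selection import KFold

# 5 folds over 5 samples, so each fold holds out exactly one sample
kf = KFold(n_splits=5, random_state=43, shuffle=True)
a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]  # features
b = [1, 2, 3, 4, 5]                            # labels
# split() yields (train indices, test indices) for each fold
for i, j in kf.split(a, b):
    print(i, j)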

# Output:
[0 1 2 4] [3]
[0 1 3 4] [2]
[0 2 3 4] [1]
[1 2 3 4] [0]
[0 1 2 3] [4]

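Not much more needs saying: the numbers above are all indices, not data values. In each round, one of the indices 0-4 is held out as the test set (the second array), and the remaining four form the training set (the first array).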
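To make the "these are indices" point concrete, here is a small sketch that uses the returned indices to select the actual rows. It assumes a and b are converted to NumPy arrays (plain Python lists cannot be indexed with an index array); the names test_seen and disjoint are illustrative only:

```python
import numpy as np
from sklearn.model_selection import KFold

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
b = np.array([1, 2, 3, 4, 5])
kf = KFold(n_splits=5, shuffle=True, random_state=43)

test_seen = []   # every index that was used as a test sample
disjoint = True  # whether train and test never overlap within a fold
for train_idx, test_idx in kf.split(a, b):
    # The index arrays select the corresponding rows and labels.
    a_train, a_test = a[train_idx], a[test_idx]
    b_train, b_test = b[train_idx], b[test_idx]
    disjoint = disjoint and set(train_idx).isdisjoint(test_idx)
    test_seen.extend(test_idx.tolist())

# Across the 5 folds, each index 0-4 is held out exactly once.
print(sorted(test_seen))
```

Shuffling only changes the order in which samples are held out; each sample still lands in the test set exactly once.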

 

posted on 2022-10-14 13:53  lmqljt
