Why does StratifiedShuffleSplit return train/test indices for the full dataset when calling next()?


I am trying to take a stratified subsample of my data, because the dataset is fairly large (roughly 100k images). I tried to be clever about it by using scikit-learn's StratifiedShuffleSplit class. The documentation gives the following example:

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 0, 1, 1, 1])
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)

for train_index, test_index in sss.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

which gives the following output (the indices for each train/test split):

TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]

Based on the above, and since the stratified shuffle split is a generator, I expected the following code (calling next()) to give me just one of the created splits:

sss = StratifiedKFold(n_splits=10, random_state=0)
train_index, test_index = next(sss.split(X, y))  # I expected this call to next() to give me the indices of ONE of the (in this case 10) splits
print(type(sss.split(X, y)))                     # Type is generator

However, when I checked len() afterwards, I saw that I actually got the complete dataset back! Can someone explain why this happens, and how I can still achieve my goal of a stratified subsample?

y_complete = np.concatenate((y[train_index], y[test_index]))            
X_complete = np.concatenate((X[train_index], X[test_index]))             
print(len(y_complete), len(X_complete)) #Gives me full length of dataset (So 99289 instead of expected 9920)

Tags: data, test, index, stratified, np, train, split
1 Answer

This is the expected behavior of the example you created. If you look at train_index and test_index separately, you will see that each contains a set of mutually exclusive indices. However, if you look at the combined set of indices in train_index + test_index, that combined set is the full dataset itself. See the code below for a clearer picture:

# Reuse the StratifiedShuffleSplit from the documentation example above
# (the outputs below correspond to n_splits=5, test_size=0.5, random_state=0)
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
split_gen = sss.split(X, y)  # Store this generator in a variable
train_index, test_index = next(split_gen)
print(type(sss.split(X, y)))  # <class 'generator'>

print("Length of Training split is {}".format(len(y[train_index])))
print("Indices are {}".format(train_index))
print("Actual data at those indices is {}".format(y[train_index]))

# Output : 
# Length of Training split is 3
# Indices are [5 2 3]
# Actual data at those indices is [1 0 1]

Notice that train_index here contains only 3 indices, not the full dataset itself. Similar behavior can be seen in test_index:

print("Length of Test split is {}".format(len(y[test_index])))
print("Indices are {}".format(test_index))
print("Actual data at those indices is {}".format(y[test_index]))

# Output : 
# Length of Test split is 3
# Indices are [4 1 0]
# Actual data at those indices is [1 0 0]

You can see that [5 2 3] and [4 1 0] are mutually exclusive, but combined they form the complete dataset, which is exactly what happens when you use np.concatenate above.
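
A quick way to check this (a minimal sketch, continuing with the X, y, train_index and test_index from above):

import numpy as np

# The two index sets are disjoint, but their union covers every sample,
# so concatenating y[train_index] and y[test_index] rebuilds the full label array.
assert set(train_index).isdisjoint(test_index)
assert set(train_index) | set(test_index) == set(range(len(y)))
print(len(np.concatenate((y[train_index], y[test_index]))), len(y))  # equal lengths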

To get the next split, call next on the generator object again:

train_index, test_index = next(split_gen)
print("Length of Set 2 Training split is {}".format(len(y[train_index])))
print("Indices are {}".format(train_index))
print("Actual data at those indices is {}".format(y[train_index]))

# Length of Set 2 Training split is 3
# Indices are [5 1 4]
# Actual data at those indices is [1 0 1]
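
As for the original goal of a stratified subsample of a large dataset: rather than concatenating anything, you can keep just one side of a single stratified split. A minimal sketch (not the only way to do it), using stand-in data since the 6-sample toy arrays above are too small for a 10% split:

import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for the real dataset: ~100k samples with integer class labels.
# (X_full / y_full are placeholder names, not variables from the question.)
rng = np.random.default_rng(0)
X_full = rng.normal(size=(100_000, 2))
y_full = rng.integers(0, 5, size=100_000)

# Keep a stratified 10% subsample and discard the rest.
X_sub, _, y_sub, _ = train_test_split(
    X_full, y_full, train_size=0.1, stratify=y_full, random_state=0
)
print(len(X_sub))  # 10000, with class proportions close to those of y_full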
