scikit学习中的分层训练/测试划分

2024-04-18 06:00:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要把我的数据分成训练集(75%)和测试集(25%)。我现在使用下面的代码:

X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)   

不过,我想对我的培训数据集进行分层。我该怎么做?我一直在研究StratifiedKFold方法,但不允许我指定75%/25%的分割,只对训练数据集进行分层。


Tags: 数据方法代码test分层trainsklearnvalidation
3条回答

[更新0.17]

参见^{}的文档:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y, 
                                                    test_size=0.25)

[/0.17的更新]

有一个请求here。 但是你可以简单地做train, test = next(iter(StratifiedKFold(...))) 如果你想的话,可以用火车和测试指标。

TL;DR:将StratifiedShuffleSplittest_size=0.25一起使用

Scikit learn提供了两个用于分层拆分的模块:

  1. StratifiedKFold:这个模块作为一个直接的k-fold交叉验证操作符很有用:因为它将设置n_folds训练/测试集,这样类在这两个模块中都是相等的。

这里有一些代码(直接来自上面的文档)

>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation
>>> len(skf)
2
>>> for train_index, test_index in skf:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
...    #fit and predict with X_train/test. Use accuracy metrics to check validation performance
  1. StratifiedShuffleSplit:这个模块创建了一个单独的训练/测试集,其中包含均衡(分层)的类。本质上,这就是您想要的n_iter=1。您可以在这里提到与train_test_split中相同的测试大小

代码:

>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
>>> len(sss)
1
>>> for train_index, test_index in sss:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
>>> # fit and predict with your classifier using the above X/y train/test

下面是一个连续/回归数据的示例(直到this issue on GitHub被解析)。

# Your bins need to be appropriate for your output values
# e.g. 0 to 50 with 25 bins
bins     = np.linspace(0, 50, 25)
y_binned = np.digitize(y_full, bins)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y_binned)

相关问题 更多 >