自定义交叉验证分割 sklearn
我正在尝试在sklearn中对一个数据集进行划分,以便进行交叉验证和网格搜索(GridSearch)。我想自己定义划分的方式,但网格搜索只接受内置的交叉验证方法。
不过,我不能使用内置的交叉验证方法,因为我需要确保某些组的示例在同一个折叠中。
比如说,如果我有这些示例:[A1, A2, A3, A4, A5, B1, B2, B3, C1, C2, C3, C4, .... , Z1, Z2, Z3]
我想进行交叉验证,确保每个组的示例[A, B, C...]只出现在一个折叠中。
也就是说,K1包含[D, E, G, J, K...],K2包含[A, C, L, M,...],K3包含[B, F, I,...]等等。
2 个回答
0
我知道这个问题已经很久了,但我也遇到过同样的问题。看起来很快就会有一个更新,让你可以解决这个问题:
13
这种情况通常可以通过 sklearn.cross_validation.LeaveOneLabelOut
来实现。你只需要构建一个标签向量,用来表示你的分组。也就是说,所有在 K1
组里的样本都标记为 1
,所有在 K2
组里的样本标记为 2
,依此类推。
下面是一个可以直接运行的示例,里面用的是假数据。重要的部分是创建 cv
对象的那一行,以及调用 cross_val_score
的那一行。
import numpy as np
n_features = 10
# Make some data
A = np.random.randn(3, n_features)
B = np.random.randn(5, n_features)
C = np.random.randn(4, n_features)
D = np.random.randn(7, n_features)
E = np.random.randn(9, n_features)
# Group it
K1 = np.concatenate([A, B])
K2 = np.concatenate([C, D])
K3 = E
data = np.concatenate([K1, K2, K3])
# Make some dummy prediction target
target = np.random.randn(len(data)) > 0
# Make the corresponding labels
labels = np.concatenate([[i] * len(K) for i, K in enumerate([K1, K2, K3])])
from sklearn.cross_validation import LeaveOneLabelOut, cross_val_score
cv = LeaveOneLabelOut(labels)
# Use some classifier in crossvalidation on data
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
scores = cross_val_score(lr, data, target, cv=cv)
不过,当然也有可能你会遇到想要完全手动定义分组的情况。在这种情况下,你需要创建一个可迭代的对象(比如一个 list
),里面包含一对对的 (train, test)
,通过索引来指明每个分组中哪些样本应该放入训练集和测试集。我们来看看这个:
# create train and test folds from our labels:
cv_by_hand = [(np.where(labels != label)[0], np.where(labels == label)[0])
for label in np.unique(labels)]
# We check this against our existing cv by converting the latter to a list
cv_to_list = list(cv)
print cv_by_hand
print cv_to_list
# Check equality
for (train1, test1), (train2, test2) in zip(cv_by_hand, cv_to_list):
assert (train1 == train2).all() and (test1 == test2).all()
# Use the created cv_by_hand in cross validation
scores2 = cross_val_score(lr, data, target, cv=cv_by_hand)
# assert equality again
assert (scores == scores2).all()