如何在googlecolab或任何其他基于ipython的环境中使KNN代码更快?

2024-04-20 12:58:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用googlecolaboratory对DonorsChoose数据集进行KNN分类。当我为avgw2v和tfidf数据集应用KNeighbors分类器时,下面的代码大约需要4小时才能执行。你知道吗

我已经试过在kaggle笔记本上运行它,但问题仍然存在。你知道吗

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import roc_auc_score
train_auc_set3 = []
cv_auc_set3 = []
K = [51, 101]
for i in tqdm(K):
    neigh = KNeighborsClassifier(n_neighbors=i, n_jobs=-1)
    neigh.fit(X_tr_set3, y_train)

    y_train_set3_pred = batch_predict(neigh, X_tr_set3)    
    y_cv_set3_pred = batch_predict(neigh, X_cr_set3)        
    train_auc_set3.append(roc_auc_score(y_train,y_train_set3_pred))
    cv_auc_set3.append(roc_auc_score(y_cv, y_cv_set3_pred))

plt.plot(K, train_auc_set3, label='Train AUC')
plt.plot(K, cv_auc_set3, label='CV AUC')

plt.scatter(K, train_auc_set3, label='Train AUC points')
plt.scatter(K, cv_auc_set3, label='CV AUC points')

plt.legend()
plt.xlabel("K: hyperparameter")
plt.ylabel("AUC")
plt.title("ERROR PLOTS")
plt.grid()
plt.show()

Tags: 数据fromimporttrainpltsklearncvlabel
1条回答
网友
1楼 · 发布于 2024-04-20 12:58:36

这可能天生就很慢。我对这个数据集不是很熟悉,但在Kaggle上看一眼,它似乎包含了超过400万个数据点。从KNN的sklearn页面:

For each iteration, time complexity is O(n_components x n_samples >x min(n_samples, n_features)).

还要记住,对于大型数据集,knn必须测量给定数据点和训练集中所有数据点之间的距离,以便进行预测,这在计算上非常昂贵。你知道吗

对于一个非常大的数据集,在k上使用大量的数字可能会得到非常差的性能。我可能会做的是:

1)查看使用单个值k拟合knn需要多少时间,以及使用单个值k对训练集进行预测需要多少时间。如果要花很长时间,那我猜这就是你的问题。你知道吗

不幸的是,有时对于非常大的数据集,我们在选择算法时受到我们可能希望使用的算法的时间复杂性的限制。例如,核岭回归是一种很好的算法,但由于需要寻找具有立方时间复杂度的矩阵逆,它不能很好地扩展到大型数据集。你知道吗

相关问题 更多 >