在非平衡数据上进行分类,精度很高,但精度很低

2024-05-16 22:17:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个不平衡的数据集,其中正类大约有10000个条目,负类大约有8000000个条目。我正在尝试一个简单的scikit的LogisticRegression模型作为基线模型,使用class_weight='balanced'(希望不平衡的问题应该得到解决?)。在

然而,我的准确度得分是0.83,但准确度得分是0.03。可能是什么问题?我需要分开处理不平衡的部分吗?在

这是我当前的代码:

>>> train = []
>>> target = []
>>> len(posList)
... 10214
>>> len(negList)
... 831134
>>> for entry in posList:
...     train.append(entry)
...     target.append(1)
...
>>> for entry in negList:
...     train.append(entry)
...     target.append(-1)
...
>>> train = np.array(train)
>>> target = np.array(target)
>>> 
>>> X_train, X_test, y_train, y_test = train_test_split(train, target, test_size=0.3, random_state=42)
>>> 
>>> model = LogisticRegression(class_weight='balanced')
>>> model.fit(X_train, y_train)
LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
          solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
>>> 
>>> predicted = model.predict(X_test)
>>> 
>>> metrics.accuracy_score(y_test, predicted)
0.835596671213
>>> 
>>> metrics.precision_score(y_test, predicted, average='weighted')
/usr/local/lib/python2.7/dist-packages/sklearn/metrics/classification.py:976: DeprecationWarning: From version 0.18, binary input will not be handled specially when using averaged precision/recall/F-score. Please use average='binary' to report only the positive class performance.
  'positive class performance.', DeprecationWarning)
0.033512518766

Tags: 模型testtargetmodel条目trainclassmetrics
1条回答
网友
1楼 · 发布于 2024-05-16 22:17:09

我想我明白发生了什么事。在

考虑一个虚拟分类器,它将为数据集的每个样本返回多数类。对于像你这样一个不平衡的集合来说,这似乎很公平(让我们称你的正类class 1和负类class 0)。分类器的精度为:831134/(831134+10214.0) = 0.987859958067292。是的,准确率为99%,但它不能很好地表示分类器。相反,我们最好看看它的精确度。因此,由于您的数据确实是不平衡的(比率1:80),Logistic回归性能较差,但对于虚拟分类器,它具有较高的精确度。在

精度定义为真阳性假阳性之和。换句话说,它是真正属于类别1的元素在被检测为属于类别1的所有元素中的比例。在

线性回归分类器的精度是acc = 0.835596671213。因此,它们是伪分类器和logistic回归的准确性差异:diff = 0.987859958067292 - 0.835596671213 = 0.15226328685429202。因此,15%的数据被伪分类器误分类,这几乎对应了n_misclass = 0.15*(831134+10214.0)=126202.2个样本。因此,Logistic回归将126202样本归类为来自类别1,而它们仅10214。在

Logistic回归的精度可能是:prec = 10214/126202.0 = 0.081。在

在你的例子中,它的准确度似乎不如我们上面看到的好。但这大概能给你一个线索,告诉你可能会发生什么。在

相关问题 更多 >