Python:如何解释和改进RandomForest中的predict_proba()

2024-06-09 21:55:06 发布

您现在位置:Python中文网/ 问答频道 /正文

因此,我使用sci工具包分类器将天文学数据分为三类。为了让我的问题更简单,我在测试集中只使用了两个来源,并获得了predict_prob()分数:

predictions = rf_model.predict(data_test)
probab =  rf_model.predict_proba(data_test)

print(probab)
print('True Classifications:', classif_test.values)
print('Predictions', predictions) 

给我以下信息:

[[0.29 0.69 0.02]
 [0.08 0.92 0.  ]]
True Classifications: ['HMXB' 'AGN']
Predictions ['HMXB' 'HMXB']

其中类顺序为[AGN, HMXB, SNR]。问题是这些预测中的一个是错误的,而另一个是正确的

我有几个问题。 (a) 我如何判断哪个predict_prob()分数对应于错误的预测? (b) predict_prob()到底在描述什么?模型的分类被认为是正确的可能性有多大? (b) 对于导致预测不准确的班级来说,高概率分数意味着什么?是我的数据集太小,还是有办法提高预测概率

所以对于我的数据,我有46 HMXB,17 AGN和3 SNR。每个源都有三个属性。我知道这是一个很小的数据集,但我想知道的是,对于随机森林或其他机器学习算法来说,它是否太小,无法给出准确的结果


Tags: 数据testtruedatamodelpredict分数print
1条回答
网友
1楼 · 发布于 2024-06-09 21:55:06

对于问题(b),predict_prob()到底描述了什么?
预测问题()将给出标签的概率
例如,如果您有三个标签['A'、'B'、'C'],并且predict_prob()给出[0.29,0.69,02],则表示该特定数据的结果有0.29变为'A'的概率,0.69变为'B'的概率,0.02变为'C'的概率。

对于问题(a),我如何判断哪一个predict_prob()分数对应于错误的预测?
从您发布的输出

[[0.29 0.69 0.02]
 [0.08 0.92 0.  ]]
Predictions ['HMXB' 'HMXB']

它清楚地表明每个列表中的第二项对应于“HMXB”。还有另外两种可能性(第一项和最后一项),我们需要查看数据并说明。

是的,你是说数据很小,而且很不平衡。因为与其他两个相比,“HMXB”有很多样本。尝试为其他标签获取更多样本

相关问题 更多 >