如何在python中用插值绘制精确召回曲线?

2024-06-07 01:14:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我用sklearnprecision_recall_curve函数和matplotlib包绘制了一个精确的召回曲线。对于那些熟悉精确回忆曲线的人来说,你知道一些科学界只在插值时接受它,类似于这个例子here。现在我的问题是,你们中是否有人知道如何在python中进行插值?我已经找了一段时间的解决方案,但没有成功!任何帮助都将不胜感激。

解决方案:弗朗西斯和阿里的两个解决方案都是正确的,共同解决了我的问题。因此,假设您从sklearn中的precision_recall_curve函数获得一个输出,下面是我为绘制图形所做的:

        precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),scores.ravel())
        pr = copy.deepcopy(precision[0])
        rec = copy.deepcopy(recall[0])
        prInv = np.fliplr([pr])[0]
        recInv = np.fliplr([rec])[0]
        j = rec.shape[0]-2
        while j>=0:
            if prInv[j+1]>prInv[j]:
                prInv[j]=prInv[j+1]
            j=j-1
        decreasing_max_precision = np.maximum.accumulate(prInv[::-1])[::-1]
        plt.plot(recInv, decreasing_max_precision, marker= markers[mcounter], label=methodNames[countOfMethods]+': AUC={0:0.2f}'.format(average_precision[0]))

如果你把插值曲线放在for循环中,并在每次迭代时将每种方法的数据传递给它,这些线就会绘制插值曲线。请注意,这不会绘制非插值精度召回曲线。


Tags: 函数np绘制pr解决方案micro曲线precision
2条回答

@francis的解决方案可以使用^{}矢量化。

import numpy as np
import matplotlib.pyplot as plt

recall = np.linspace(0.0, 1.0, num=42)
precision = np.random.rand(42)*(1.-recall)

# take a running maximum over the reversed vector of precision values, reverse the
# result to match the order of the recall vector
decreasing_max_precision = np.maximum.accumulate(precision[::-1])[::-1]

也可以使用^{}来摆脱用于绘制的for循环:

fig, ax = plt.subplots(1, 1)
ax.hold(True)
ax.plot(recall, precision, '--b')
ax.step(recall, decreasing_max_precision, '-r')

enter image description here

可以执行反向迭代来删除precision中增加的部分。然后,根据Bennett Brown对vertical & horizontal lines in matplotlib的回答,可以绘制垂直线和水平线。

下面是一个示例代码:

import numpy as np
import matplotlib.pyplot as plt

#just a dummy sample
recall=np.linspace(0.0,1.0,num=42)
precision=np.random.rand(42)*(1.-recall)
precision2=precision.copy()
i=recall.shape[0]-2

# interpolation...
while i>=0:
    if precision[i+1]>precision[i]:
        precision[i]=precision[i+1]
    i=i-1

# plotting...
fig, ax = plt.subplots()
for i in range(recall.shape[0]-1):
    ax.plot((recall[i],recall[i]),(precision[i],precision[i+1]),'k-',label='',color='red') #vertical
    ax.plot((recall[i],recall[i+1]),(precision[i+1],precision[i+1]),'k-',label='',color='red') #horizontal

ax.plot(recall,precision2,'k--',color='blue')
#ax.legend()
ax.set_xlabel("recall")
ax.set_ylabel("precision")
plt.savefig('fig.jpg')
fig.show()

结果是:

enter image description here

相关问题 更多 >