如何计算对象关键点相似度

2024-06-16 13:56:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图计算对象关键点相似性来评估算法的关键点检测。下面是我根据从here中发现和理解的内容编写的代码

oks formula

def oks(gt, preds, threshold, v, gt_area):

ious = np.zeros((len(preds), len(gt)))
sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
vars = (sigmas*2)**2
k = len(sigmas)

xg = gt[:, 0]; yg = gt[:, 1]
xp = preds[:, 0]; yp = preds[:, 1]
vg = v + 1 # add one to visibility tags
k1 = np.count_nonzero(vg > 0)
dx = np.subtract(xg, xp)
dy = np.subtract(yg, yp)

e = (dx**2+dy**2)/vars/(gt_area+np.spacing(1))/2
if threshold is not None:
    ind = list(vg > threshold)
    e = e[ind]
ious = np.sum(np.exp(-e))/(1.5*e.shape[0]) if len(e) != 0 else 0
return ious

在哪里,

gt、PRED是17x2 NumPy阵列,包含17(x,y)个人体姿势坐标,分别用于地面真实值和机器学习模型的预测

阈值=0.5(coco数据集使用0.5作为软阈值)

v=地面真值关键点(17x1 NumPy数组)的可见性,值为0=可见,1=遮挡(因此我们采用vg=v+1以符合oks公式)

gt_面积=地面真相人员边界框的面积

我的印象是oks应该为每个关键点生成一个值,但是上面的代码为所有的关键点组合生成一个值。我做错什么了吗


Tags: 代码gtthresholdlennpareavars关键点