使用Numpy以到输入矩阵的最短距离获取训练集中数据点的索引

2024-05-13 22:45:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我想构建一个函数npbatch(U,X),它将输入矩阵(U)中的数据点与训练矩阵(X)中的数据点进行比较,并获得X的索引,该索引与U中的数据点的欧氏距离最短。 我希望避免任何循环以提高性能,并且我希望使用函数scipy.spatial.distance.cdist来计算距离

输入示例:

U
array([[0.69646919, 0.28613933, 0.22685145],
       [0.55131477, 0.71946897, 0.42310646],
       [0.9807642 , 0.68482974, 0.4809319 ]])

X
array([[0.24875591, 0.16306678, 0.78364326],
       [0.80852339, 0.62562843, 0.60411363],
       [0.8857019 , 0.75911747, 0.18110506]])

——>;预期输出:具有X中数据点的三个索引的数组,该索引与U中三个数据点的距离最短

我的总体目标是使用我得到的索引获得相应数据点的标签。标签输入的示例如下:

Y
array([1, 0, 0])

谢谢你的提示


Tags: 数据函数gt距离示例矩阵scipy标签
1条回答
网友
1楼 · 发布于 2024-05-13 22:45:18

使用scipy.spatial.distance.cdist您已经为任务选择了一个非常适合的函数。要获得索引,我们只需沿轴0应用^{}(或轴1表示cdist(U, X)):

ix = numpy.argmin(scipy.spatial.distance.cdist(X, U), 0)

因此,获取标签是一件微不足道的事情:

Y[ix]

相关问题 更多 >