在numpy中查找浮点数组的唯一元素(使用差值比较)

31 投票
6 回答
23011 浏览
提问于 2025-04-16 14:22

我在使用numpy的时候,有一个包含浮点数的ndarray数组,我想找出这个数组中独特的值。由于浮点数的精度问题,这个过程会有一些麻烦……所以我想设置一个“差值”(delta)来帮助比较,找出哪些元素是独特的。

有没有办法做到这一点?目前我只是这样做:

unique(array)

这样做的结果大概是:

array([       -Inf,  0.62962963,  0.62962963,  0.62962963,  0.62962963,
    0.62962963])

在这里,看起来相同的值(根据显示的小数位数)实际上是有一点点不同的。

6 个回答

6

我刚注意到,接受的答案并不好用。比如这个例子:

a = 1-np.random.random(20)*0.05
<20 uniformly chosen values between 0.95 and 1.0>
np.sort(a)
>>>> array([ 0.9514548 ,  0.95172218,  0.95454535,  0.95482343,  0.95599525,
             0.95997008,  0.96385762,  0.96679186,  0.96873524,  0.97016127,
             0.97377579,  0.98407259,  0.98490461,  0.98964753,  0.9896733 ,
             0.99199411,  0.99261766,  0.99317258,  0.99420183,  0.99730928])
TOL = 0.01

结果是:

a.flat[i[d>TOL]]
>>>> array([], dtype=float64)

这是因为排序后的输入数组中的值之间的间距不够,不能至少相差“TOL”那么远,而正确的结果应该是:

>>>> array([ 0.9514548,  0.96385762,  0.97016127,  0.98407259,
             0.99199411])

(当然,这还要看你怎么决定在“TOL”范围内选择哪个值)

你应该利用整数不会受到机器精度影响这个事实:

np.unique(np.floor(a/TOL).astype(int))*TOL
>>>> array([ 0.95,  0.96,  0.97,  0.98,  0.99])

这个方法的速度比建议的解决方案快5倍(根据%timeit的测试结果)。

注意,“.astype(int)”是可选的,虽然去掉它会让性能下降1.5倍,因为从整数数组中提取唯一值要快得多。

你可能想要在唯一值的结果上加上“TOL”的一半,以补偿取整带来的影响:

(np.unique(np.floor(a/TOL).astype(int))+0.5)*TOL
>>>> array([ 0.955,  0.965,  0.975,  0.985,  0.995])
33

另一种可能性是将数值四舍五入到最接近的合适范围:

np.unique(a.round(decimals=4))

这里的 a 是你原来的数组。

补充: 需要说明的是,我的解决方案和 @unutbu 的在速度上几乎是一样的(我的可能快5%左右),所以两者都是不错的选择。

补充 #2: 这是为了回应保罗的担忧。确实会慢一些,可能还可以做一些优化,但我就这样发出来,目的是展示这个策略:

def eclose(a,b,rtol=1.0000000000000001e-05, atol=1e-08):
    return np.abs(a - b) <= (atol + rtol * np.abs(b))

x = np.array([6.4,6.500000001, 6.5,6.51])
y = x.flat.copy()
y.sort()
ci = 0

U = np.empty((0,),dtype=y.dtype)

while ci < y.size:
    ii = eclose(y[ci],y)
    mi = np.max(ii.nonzero())
    U = np.concatenate((U,[y[mi]])) 
    ci = mi + 1

print U

如果在精度范围内有很多重复的值,这个方法应该会比较快,但如果很多值都是独一无二的,那就会比较慢。此外,可能更好的是把 U 设置成一个列表,并通过 while 循环来添加,但这属于“进一步优化”的范畴。

14

难道 floorround 在某些情况下都不符合提问者的要求吗?

np.floor([5.99999999, 6.0]) # array([ 5.,  6.])
np.round([6.50000001, 6.5], 0) #array([ 7.,  6.])

我会这样做(虽然这可能不是最优的方案,也肯定比其他答案慢),大概是这样的:

import numpy as np
TOL = 1.0e-3
a = np.random.random((10,10))
i = np.argsort(a.flat)
d = np.append(True, np.diff(a.flat[i]))
result = a.flat[i[d>TOL]]

当然,这种方法会排除掉一系列值中除了最大的那个以外的所有值,只要这些值都在某个容忍范围内。这意味着如果数组中的所有值都非常接近,即使最大值和最小值的差距大于容忍范围,你也可能找不到任何独特的值。

这里有一个基本上相同的算法,但更容易理解,并且应该更快,因为它避免了索引步骤:

a = np.random.random((10,))
b = a.copy()
b.sort()
d = np.append(True, np.diff(b))
result = b[d>TOL]

提问者可能还想看看 scipy.cluster(这是这个方法的高级版本)或者 numpy.digitize(这是另外两种方法的高级版本)

撰写回答