numpy.unique 怎么消除重复列?

1 投票
2 回答
60 浏览
提问于 2025-04-12 17:16

我不太明白Numpy的unique函数在处理多维数组时是怎么工作的。更具体地说,我对unique的文档中关于axis参数的描述感到困惑:

https://numpy.org/doc/stable/reference/generated/numpy.unique.html

当指定一个轴时,沿着这个轴的子数组会被排序。这通过将指定的轴移动到数组的第一维来实现(把这个轴放到第一维,以保持其他轴的顺序),然后按照C语言的顺序将子数组展平。展平后的子数组会被视为一种结构化类型,每个元素都有一个标签,这样我们最终得到一个一维的结构化类型数组,可以像处理其他一维数组一样处理它。结果是,展平后的子数组会按照字典顺序排序,从第一个元素开始。

我已经多次阅读了上面提到的段落,但遗憾的是,我仍然无法清楚地理解这个过程,特别是我上面加粗的部分。举个例子,假设我们有以下的二维数组:

import numpy as np

myarray = np.array(
    [
        [ 1,  3,  7,  8,  3], 
        [-5,  0,  9,  2,  0], 
        [10, 11, 12, 85, 11]
    ]
)

如你所见,第二列和第五列的值{3, 0, 11}是重复的。如果我想用numpy.unique来去掉重复的列,我会运行以下代码:

np.unique(myarray, axis=1)

这会得到预期的结果:

array([[ 1,  3,  7,  8],
       [-5,  0,  9,  2],
       [10, 11, 12, 85]])

第五列确实被去掉了,因为它是第二列的重复项。所以从视觉上看,结果是可以理解的。然而,如果我阅读上面提到的文档,试图按照建议将选定的轴移动到数组的第一维,然后展平结果的子数组,我就无法理解Numpy是如何拆分和重新组织数组的结构,以得到最终结果的。

你能否提供一个基于上述文档的逐步说明,描述Numpy是如何得出这个结果的?

2 个回答

2

指定轴:1。这意味着我们关注的是数组的列,以便识别唯一的列。

# Conceptual representation, not an actual operation here
# Each column is considered as a single item:
[
 [ 1, -5, 10],
 [ 3,  0, 11],
 [ 7,  9, 12],
 [ 8,  2, 85],
 [ 3,  0, 11]  # Duplicate of the second column
]
  • 在这里,NumPy 实际上把每一列当作一个结构化的数组元素来看,也就是说,它是整体地看待这一列,而不是逐个元素地看。
  • NumPy 将这些“结构化类型”(我们的列)按字典顺序排序,以便高效地找到重复项。这就像比较字符串一样:[1, -5, 10] 和 [3, 0, 11] 是不同的,但它会注意到 [3, 0, 11] 出现了两次:
 - "1,-5,10"
 - "3,0,11"
 - "7,9,12"
 - "8,2,85"
 - "3,0,11"  # This is a duplicate when viewed as a "string"

在 numpy.unique 对数组进行 axis=1 操作后,你会得到:

array([[ 1,  3,  7,  8],
       [-5,  0,  9,  2],
       [10, 11, 12, 85]])
2

源代码的链接可以在文档页面找到。

https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L138-L320

基本步骤如下:

首先,把轴移动到开始的位置:

In [103]: ar = np.moveaxis(myarray, 1, 0)

然后,创建一个连续的二维数组:

In [104]: orig_shape, orig_dtype = ar.shape, ar.dtype
     ...: ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))
     ...: ar = np.ascontiguousarray(ar)    
In [105]: ar
Out[105]: 
array([[ 1, -5, 10],
       [ 3,  0, 11],
       [ 7,  9, 12],
       [ 8,  2, 85],
       [ 3,  0, 11]])

接着,定义一个复合数据类型,并根据这个类型创建一个结构化数组:

In [106]: dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])];dtype
Out[106]: [('f0', dtype('int32')), ('f1', dtype('int32')), ('f2', dtype('int32'))]

In [107]: consolidated = ar.view(dtype); consolidated
Out[107]: 
array([[(1, -5, 10)],
       [(3,  0, 11)],
       [(7,  9, 12)],
       [(8,  2, 85)],
       [(3,  0, 11)]], dtype=[('f0', '<i4'), ('f1', '<i4'), ('f2', '<i4')])

这个数组之后可以当作一维数组来处理:

In [108]: np.lib.arraysetops._unique1d(consolidated)
Out[108]: 
(array([(1, -5, 10), (3,  0, 11), (7,  9, 12), (8,  2, 85)],
       dtype=[('f0', '<i4'), ('f1', '<i4'), ('f2', '<i4')]),)

还有进一步的代码可以把这个“唯一”的数组重新调整回原来的数据类型和形状。

unique1d 中,关键是对数组进行排序(去掉大小为1的尾部维度):

In [109]: np.sort(consolidated[:,0])
Out[109]: 
array([(1, -5, 10), (3,  0, 11), (3,  0, 11), (7,  9, 12), (8,  2, 85)],
      dtype=[('f0', '<i4'), ('f1', '<i4'), ('f2', '<i4')])

然后查找相邻的重复项。

所以,主要的“难点”是构建一个可以作为一维数组排序的结构化数组视图。这个排序是按字典顺序进行的——先看第一个字段,然后是第二个,以此类推。

对每一行的字符串格式进行排序也能得到类似的结果,但结构化数组的版本更通用。

In [121]: x=[f'{row}' for row in ar];x
Out[121]: ['[ 1 -5 10]', '[ 3  0 11]', '[ 7  9 12]', '[ 8  2 85]', '[ 3  0 11]']

In [122]: x.sort();x
Out[122]: ['[ 1 -5 10]', '[ 3  0 11]', '[ 3  0 11]', '[ 7  9 12]', '[ 8  2 85]']

撰写回答