numpy数组中索引到标签的最快逆操作:标签到索引的字典(散列)

2024-04-24 19:29:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我发现自己总是在几个不同的场景中面对这个问题。所以我想在这里分享一下,看看有没有最佳的解决方法。在

假设我有一个很大的数组,它有一个与X大小相同的数组,叫做y,它上面有X所属的标签。像下面这样。在

X = np.array(['obect1', 'object2', 'object3', 'object4', 'object5'])
y = np.array([0, 1, 1, 0, 2])

我想要的是构建一个字典/散列,它使用一组标签作为键和所有对象的索引,其中X中的这些标签作为项。因此,在这种情况下,期望的输出将是:

^{pr2}$

请注意,实际上X上的内容并不重要,但为了完整起见,我将其包括在内。在

现在,我对这个问题的简单解决方案是遍历所有标签并使用np.where==label来构建字典。更详细地说,我使用这个函数:

def get_key_to_indexes_dic(labels):
    """
    Builds a dictionary whose keys are the labels and whose
    items are all the indexes that have that particular key
    """

    # Get the unique labels and initialize the dictionary
    label_set = set(labels)
    key_to_indexes = {}

    for label in label_set:
        key_to_indexes[label] = np.where(labels==label)

    return key_to_indexes

所以现在我的问题的核心是: 有没有办法做得更好?有没有一种自然的方法可以使用numpy函数来解决这个问题?我的方法是不是被误导了?在


作为一个不太重要的横向问题:上述定义中的解决方案的复杂性是什么?我认为解决方案的复杂性如下:

enter image description here

或者换句话说,标签的数量乘以在一个y大小的集合中使用np.where的复杂度加上从数组中建立集合的复杂性。这是对的吗?在

p.D.我找不到与这个具体问题相关的帖子,如果你有建议改变标题或任何我将不胜感激。在


Tags: theto方法keylabelsnp标签数组
3条回答

numpy_indexed包(免责声明:我是其作者)可用于以完全矢量化的方式解决此类问题,并且具有O(nlogn)最坏情况下的时间复杂性:

import numpy_indexed as npi
indices = np.arange(len(labels))
unique_labels, indices_per_label = npi.group_by(labels, indices)

请注意,对于此类功能的许多常见应用程序,例如计算组标签上的和或平均值,不计算指数的拆分列表,而是利用npi中的函数来进行比较,即,npi.分组依据(标签).mean(一些对应的_数组),而不是在每个标签上循环索引并取这些指数的平均值。在

假设标签是连续整数[0, m],取{},则{}的复杂度为O(n),循环中{}的复杂度为O(m*n)。但是,总体复杂度写为O(m*n)而不是O(m*n+n),参见"Big O notation" on wikipedia。在


有两件事可以提高性能:1)使用更高效的算法(较低的复杂性)和2)用快速数组操作替换Python循环。在

目前发布的其他答案正是这样做的,并且使用了非常合理的代码。然而,一个最优解既可以完全矢量化,又具有O(n)复杂度。这可以通过使用Scipy的某个较低级别函数来实现:

def sparse_hack(labels):
    from scipy.sparse._sparsetools import coo_tocsr

    labels = labels.ravel()
    n = len(labels)
    nlabels = np.max(labels) + 1

    indices = np.arange(n)
    sorted_indices = np.empty(n, int)
    offsets = np.zeros(nlabels+1, int)
    dummy = np.zeros(n, int)
    coo_tocsr(nlabels, 1, n, labels, dummy, indices, 
                             offsets, dummy, sorted_indices)

    return sorted_indices, offsets

coo_tocsr的源可以找到here。我使用它的方式是,它本质上执行一个间接的counting sort。老实说,这是一个相当模糊的方法,我建议你在其他答案中使用其中一种方法。在

如果在遍历过程中使用字典存储索引,则只需遍历一次:

from collections import defaultdict

def get_key_to_indexes_ddict(labels):
    indexes = defaultdict(list)
    for index, label in enumerate(labels):
        indexes[label].append(index)

缩放看起来很像您为您的选项分析过的,对于上面的函数,它是O(N),其中N是y的大小,因为检查字典中的值是否为O(1)。在

有趣的是,既然np.where的遍历速度要快得多,只要只有少量的标签,你的函数就会更快。当有许多不同的标签时,我的速度似乎更快。在

以下是函数的缩放方式:

scaling graph

蓝线是你的职能,红线是我的。线样式指示不同标签的数量。{10: ':', 100: '--', 1000: '-.', 10000: '-'}。您可以看到,我的函数相对独立于标签的数量,而您的函数在有很多标签时会很快变慢。如果你的商标不多,你最好用你的。在

相关问题 更多 >