numpy中'ismember'的等效函数是什么?
我在找一个Numpy的用法,想要实现一个在Matlab中用到的“模式”,这个模式是用来处理ismember的。
可惜的是,这段代码通常是我在Matlab脚本中花费时间最多的地方,所以我想找到一个高效的Numpy替代方案。
这个基本的模式是将一部分数据映射到一个更大的网格上。我有一组键值对,存储在两个并行的数组中,我想把这些值插入到一个更大的键值对列表中,这个列表也是以相同的方式存储的。
为了更具体一点,假设我有季度GDP数据,我想把它映射到一个按月的时间网格上,像这样。
quarters = [200712 200803 200806 200809 200812 200903];
gdp_q = [10.1 10.5 11.1 11.8 10.9 10.3];
months = 200801 : 200812;
gdp_m = NaN(size(months));
[tf, loc] = ismember(quarters, months);
gdp_m(loc(tf)) = gdp_q(tf);
需要注意的是,并不是所有的季度都出现在月份的列表中,所以tf和loc这两个变量都是必需的。
我在StackOverflow上看到过类似的问题,但要么只是给出纯Python的解决方案(在这里),要么在使用Numpy时没有返回loc参数(在这里)。
在我特定的应用领域,这种代码模式经常出现,并且占用了我函数大部分的CPU时间,所以找到一个高效的解决方案对我来说非常重要。
欢迎大家提出意见或重新设计的建议。
3 个回答
可以试试从pypi上下载的 ismember
库。
pip install ismember
举个例子:
# Import library
from ismember import ismember
# data
quarters = np.array([200712, 200803, 200806, 200809, 200812, 200903])
months = np.arange(200801, 200812)
# Lookup
Iloc,idx=ismember(quarters, months)
# Iloc is boolean defining existence of quarters in months
print(Iloc)
# [False, True, True, True, False, False]
# index of months that exists in quarters
print(idx)
# [2, 5, 8]
print(months[idx])
[200803, 200806, 200809]
print(quarters[Iloc])
[200803, 200806, 200809]
# These vectors will match
quarters[Iloc]==months[idx]
我觉得你可以重新设计你提供的原始MATLAB代码,这样就不需要使用ISMEMBER这个函数了。这样可能会让MATLAB代码运行得更快,而且如果你想的话,转到Python实现起来也会更简单:
quarters = [200712 200803 200806 200809 200812 200903];
gdp_q = [10.1 10.5 11.1 11.8 10.9 10.3];
monthStart = 200801; %# Starting month value
monthEnd = 200812; %# Ending month value
nMonths = monthEnd-monthStart+1; %# Number of months
gdp_m = NaN(1,nMonths); %# Initialize gdp_m
quarters = quarters-monthStart+1; %# Shift quarter values so they can be
%# used as indices into gdp_m
index = (quarters >= 1) & (quarters <= nMonths); %# Logical index of quarters
%# within month range
gdp_m(quarters(index)) = gdp_q(index); %# Move values from gdp_q to gdp_m
如果你的月份数据是排好序的,可以直接使用 np.searchsorted
。如果没有排好序,先把它们排序,然后再用 np.searchsorted
:
import numpy as np
quarters = np.array([200712, 200803, 200806, 200809, 200812, 200903])
months = np.arange(200801, 200813)
loc = np.searchsorted(months, quarters)
np.searchsorted
会告诉你插入的位置。如果你的数据可能不在正确的范围内,建议你之后再检查一下:
valid = (quarters <= months.max()) & (quarters >= months.min())
loc = loc[valid]
这个方法的时间复杂度是 O(N log N)。如果在你的程序中,这个运行时间还是太长了,你可以考虑用 C(++) 语言写一个子程序,使用哈希方法,这样可以把时间复杂度降到 O(N)(当然,这样也能避免一些常数因素的影响)。