聚合数组的Pythonic方法(使用numpy或其他)
我想写一个很不错的函数,用来对一个数组里的数据进行汇总(这个数组是numpy的记录数组,但这并不影响)。
假设你有一个数据数组,想要在某个轴上进行汇总:比如说这个数组的格式是 dtype=[(name, (np.str_,8), (job, (np.str_,8), (income, np.uint32)]
,你想要计算每个工作的平均收入。
我写了这个函数,在例子中应该这样调用:aggregate(data,'job','income',mean)
def aggregate(data, key, value, func):
data_per_key = {}
for k,v in zip(data[key], data[value]):
if k not in data_per_key.keys():
data_per_key[k]=[]
data_per_key[k].append(v)
return [(k,func(data_per_key[k])) for k in data_per_key.keys()]
问题是我觉得这个函数不够简洁,我希望能把它写成一行:你有什么好主意吗?
谢谢你的回答,Louis
PS:我希望这个函数可以在调用时保持灵活,这样你也可以请求中位数、最小值等等。
6 个回答
2
更新 2022:
现在有一个软件包,它的功能跟matlab里的accumarray非常相似。你可以通过 pip install numpy_groupies
来安装它,或者你可以在这里找到它:
5
你的这段代码 if k not in data_per_key.keys()
可以简化成 if k not in data_per_key
,但你还可以用 defaultdict
来做得更好。下面是一个使用 defaultdict
的版本,这样就不需要检查键是否存在了:
import collections
def aggregate(data, key, value, func):
data_per_key = collections.defaultdict(list)
for k,v in zip(data[key], data[value]):
data_per_key[k].append(v)
return [(k,func(data_per_key[k])) for k in data_per_key.keys()]
5
也许你想找的功能是 matplotlib.mlab.rec_groupby:
import matplotlib.mlab
data=np.array(
[('Aaron','Digger',1),
('Bill','Planter',2),
('Carl','Waterer',3),
('Darlene','Planter',3),
('Earl','Digger',7)],
dtype=[('name', np.str_,8), ('job', np.str_,8), ('income', np.uint32)])
result=matplotlib.mlab.rec_groupby(data, ('job',), (('income',np.mean,'avg_income'),))
这个功能会产生
('Digger', 4.0)
('Planter', 2.5)
('Waterer', 3.0)
matplotlib.mlab.rec_groupby
会返回一个叫做 recarray 的东西:
print(result.dtype)
# [('job', '|S7'), ('avg_income', '<f8')]