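numpy - How to do an outer join on arrays
I am trying to combine these three arrays into the array shown below. It is essentially a SQL outer join, with the 'pos' field as the key/index.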
a1 = array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
            ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
            ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
           dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a2 = array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
            ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
            ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),],
           dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a3 = array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
            ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
            ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
           dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])
Desired result:
array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695, 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605, 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001, 1.9849650921836601e-31, 0.99999999997999001),
       ('3:6506', 4.6725971801473496e-25, 0.99999999995088695, NaN, NaN),
       ('3:6601', 2.2452745388799898e-27, 0.99999999995270605, NaN, NaN),
       ('3:21801', 1.9849650921836601e-31, 0.99999999997999001, NaN, NaN),
      ],
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8'), ('col3', '<f8'), ('col4', '<f8')])
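I think this answer might be a good starting point, but I can't quite figure out how to apply it.
Update:
I tried running unutbu's answer, but got this error: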
Traceback (most recent call last):
  File "fail2.py", line 21, in <module>
    a4 = recfunctions.join_by('pos', a4, a, jointype='outer')
  File "/usr/local/msg/lib/python2.6/site-packages/numpy/lib/recfunctions.py", line 973, in join_by
    current = output[f]
  File "/usr/local/msg/lib/python2.6/site-packages/numpy/ma/core.py", line 2943, in __getitem__
    dout = ndarray.__getitem__(_data, indx)
ValueError: field named col12 not found.
Update 2
I only hit this error with numpy 1.5.1. After upgrading to 1.8.1 it works fine.
1 Answer
import numpy as np
import numpy.lib.recfunctions as recfunctions
# Three structured arrays that share the 'pos' key field.
a1 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a2 = np.array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a3 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])

# Outer-join the arrays one at a time, keyed on every field name they share.
result = a1
for a in (a2, a3):
    cols = list(set(result.dtype.names).intersection(a.dtype.names))
    result = recfunctions.join_by(cols, result, a, jointype='outer')
print(result)
yields
[ ('2:21801', 1.98496509218366e-31, 0.99999999997999, 1.98496509218366e-31, 0.99999999997999)
('2:6506', 4.67259718014735e-25, 0.999999999950887, 4.67259718014735e-25, 0.999999999950887)
('2:6601', 2.24527453887999e-27, 0.999999999952706, 2.24527453887999e-27, 0.999999999952706)
('3:21801', 1.98496509218366e-31, 0.99999999997999, --, --)
('3:6506', 4.67259718014735e-25, 0.999999999950887, --, --)
('3:6601', 2.24527453887999e-27, 0.999999999952706, --, --)]
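The `--` entries above are masked values, whereas the question asked for NaN. One possible way to bridge that gap is sketched below. The helper name `masked_to_nan` is hypothetical, and `joined` is a small hand-built masked structured array standing in for the output of `join_by`, so treat this as an illustration under those assumptions rather than a tested recipe for every NumPy version.

```python
import numpy as np

# Hand-built stand-in for a join_by(..., jointype='outer') result:
# a masked structured array whose second row has no 'col3' value.
dtype = [("pos", "U10"), ("col1", "<f8"), ("col3", "<f8")]
joined = np.ma.array(
    [("2:6506", 1.0, 10.0), ("3:6506", 2.0, 0.0)],
    dtype=dtype,
    mask=[(False, False, False), (False, False, True)],
)


def masked_to_nan(marr):
    """Return a plain structured array with masked floats replaced by NaN.

    Hypothetical helper: non-float fields are filled with their default
    fill value, since NaN cannot be stored in them.
    """
    out = np.empty(marr.shape, dtype=marr.dtype)
    for name in marr.dtype.names:
        col = marr[name]
        if col.dtype.kind == "f":
            out[name] = col.filled(np.nan)
        else:
            out[name] = col.filled()
    return out


filled = masked_to_nan(joined)
print(filled)
```

Working field by field avoids asking NumPy to store NaN in the string key field.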
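If you are using NumPy arrays for SQL-like join operations, you might want to consider Pandas. Pandas is built on top of NumPy and provides many more data-manipulation features: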
import numpy as np
import pandas as pd
a1 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a2 = np.array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])
a3 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
               ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
               ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),],
              dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])

# Convert each structured array to a DataFrame, then outer-merge them
# one at a time on every column name they share.
dfs = [pd.DataFrame.from_records(a) for a in (a1, a2, a3)]
result = dfs[0]
for df in dfs[1:]:
    cols = list(set(result.columns).intersection(df.columns))
    result = pd.merge(result, df, how='outer', on=cols)
print(result)
yields
       pos          col1  col2          col3  col4
0   2:6506  4.672597e-25     1  4.672597e-25     1
1   2:6601  2.245275e-27     1  2.245275e-27     1
2  2:21801  1.984965e-31     1  1.984965e-31     1
3   3:6506  4.672597e-25     1           NaN   NaN
4   3:6601  2.245275e-27     1           NaN   NaN
5  3:21801  1.984965e-31     1           NaN   NaN

[6 rows x 5 columns]
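The same chained merge can also be written as a fold over the list of frames. The sketch below uses small made-up frames (`df1`, `df2`, `df3`) and a hypothetical helper `outer_join`; it is only meant to show the shape of the idea, not to replace the code above.

```python
from functools import reduce

import pandas as pd

# Small illustrative frames; the values are made up for this sketch.
df1 = pd.DataFrame({"pos": ["2:6506", "2:6601"], "col1": [1.0, 2.0]})
df2 = pd.DataFrame({"pos": ["3:6506"], "col1": [3.0]})
df3 = pd.DataFrame({"pos": ["2:6506", "2:6601"], "col3": [10.0, 20.0]})


def outer_join(left, right):
    """Outer-merge two frames on every column name they share."""
    # A list comprehension keeps the left frame's column order,
    # unlike a set intersection, whose order is arbitrary.
    shared = [c for c in left.columns if c in right.columns]
    return pd.merge(left, right, how="outer", on=shared)


# Fold the frames together from left to right.
result = reduce(outer_join, [df1, df2, df3])
print(result)
```

Rows whose key appears in only some of the frames get NaN in the columns they lack, which matches the NaN layout requested in the question.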
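Pandas can sometimes be slightly slower than a pure NumPy solution. That is usually because Pandas provides a more complete solution that correctly handles corner cases such as NaNs (missing values) or duplicate index values, which a simple NumPy solution may not take into account.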
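Note also that a Pandas DataFrame has a .values attribute, which returns the underlying data as a NumPy array, and a .to_records method, which returns a structured array. As shown above, there is also a DataFrame.from_records constructor that converts a structured array into a DataFrame. So if you really need to, you can easily move back and forth between DataFrames and NumPy arrays.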
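As a rough illustration of that round trip, the sketch below moves a tiny made-up structured array into a DataFrame and back. The array `arr` and its values are assumptions for the example, and `to_numpy()` is used where older code would read `.values`.

```python
import numpy as np
import pandas as pd

# Tiny made-up structured array with a string key and two float fields.
arr = np.array(
    [("2:6506", 1.0, 2.0), ("2:6601", 3.0, 4.0)],
    dtype=[("pos", "U100"), ("col1", "<f8"), ("col2", "<f8")],
)

# Structured array -> DataFrame: field names become column names.
df = pd.DataFrame.from_records(arr)

# DataFrame -> record array; index=False leaves the row index out.
rec = df.to_records(index=False)

# Numeric columns only -> plain 2-D float array.
plain = df[["col1", "col2"]].to_numpy()

print(df)
print(rec.dtype.names)
print(plain)
```

Note that the string field may come back with an object dtype after the round trip, so an explicit cast can be needed if a fixed-width string field is required.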
So I don't think there is any real speed disadvantage to using Pandas, and the convenience it offers should make your data analysis easier.