numpy - How to do an outer join on arrays

4 votes
1 answer
7619 views
Asked 2025-04-18 05:34

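I'm trying to merge these three arrays into the array shown below. It is essentially an outer join as in SQL, with the 'pos' field acting as the key/index.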

a1 = array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a2 = array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a3 = array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])

Desired result:

array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695, 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605, 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001, 1.9849650921836601e-31, 0.99999999997999001),
       ('3:6506', 4.6725971801473496e-25, 0.99999999995088695, NaN, NaN),
       ('3:6601', 2.2452745388799898e-27, 0.99999999995270605, NaN, NaN),
       ('3:21801', 1.9849650921836601e-31, 0.99999999997999001, NaN, NaN),
        ], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8'), ('col3', '<f8'), ('col4', '<f8')])

I think this answer might be a good starting point, but I can't quite figure out how to apply it.

Update:

I tried running unutbu's answer, but got this error:

Traceback (most recent call last):
  File "fail2.py", line 21, in <module>
    a4 = recfunctions.join_by('pos', a4, a, jointype='outer')
  File "/usr/local/msg/lib/python2.6/site-packages/numpy/lib/recfunctions.py", line 973, in join_by
    current = output[f]
  File "/usr/local/msg/lib/python2.6/site-packages/numpy/ma/core.py", line 2943, in __getitem__
    dout = ndarray.__getitem__(_data, indx)
ValueError: field named col12 not found.

Update 2

I only get this error with numpy 1.5.1. After upgrading to 1.8.1, it works fine.

1 Answer

6
import numpy as np
import numpy.lib.recfunctions as recfunctions

a1 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a2 = np.array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a3 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])

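# Fold the remaining arrays into the result one at a time,
# joining on whatever field names the two sides share.
result = a1
for a in (a2, a3):
    cols = list(set(result.dtype.names).intersection(a.dtype.names))
    result = recfunctions.join_by(cols, result, a, jointype='outer')
print(result)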

yields

[ ('2:21801', 1.98496509218366e-31, 0.99999999997999, 1.98496509218366e-31, 0.99999999997999)
 ('2:6506', 4.67259718014735e-25, 0.999999999950887, 4.67259718014735e-25, 0.999999999950887)
 ('2:6601', 2.24527453887999e-27, 0.999999999952706, 2.24527453887999e-27, 0.999999999952706)
 ('3:21801', 1.98496509218366e-31, 0.99999999997999, --, --)
 ('3:6506', 4.67259718014735e-25, 0.999999999950887, --, --)
 ('3:6601', 2.24527453887999e-27, 0.999999999952706, --, --)]
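The `--` entries are masked values, because `join_by` returns a masked array by default, whereas the question asked for NaN in those slots. The following is a minimal sketch of one way to convert, not part of the original answer; the helper name `masked_to_nan` and the two small arrays are made up for illustration.

```python
import numpy as np
import numpy.lib.recfunctions as recfunctions


def masked_to_nan(joined):
    """Copy a masked structured array into a plain one, NaN where masked.

    Hypothetical helper: only floating-point fields get NaN; other
    fields keep whatever underlying data they hold.
    """
    out = np.empty(joined.shape, dtype=joined.dtype)
    for name in joined.dtype.names:
        col = joined[name]
        out[name] = np.ma.getdata(col)
        if joined.dtype[name].kind == 'f':
            # out[name] is a view, so this assignment writes into out.
            out[name][np.ma.getmaskarray(col)] = np.nan
    return out


# Illustrative data, much smaller than the arrays in the question.
left = np.array([('2:6506', 1.0), ('3:6506', 2.0)],
                dtype=[('pos', 'S100'), ('col1', '<f8')])
right = np.array([('2:6506', 3.0)],
                 dtype=[('pos', 'S100'), ('col3', '<f8')])

joined = recfunctions.join_by('pos', left, right, jointype='outer')
plain = masked_to_nan(joined)
print(plain)
```

The result is an ordinary structured `ndarray`, so code that does not understand masks can still consume it.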

If you are doing SQL-like joins on NumPy arrays, consider using Pandas. Pandas is built on top of NumPy and offers far more data-manipulation functionality:

import numpy as np
import pandas as pd
a1 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a2 = np.array([('3:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('3:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('3:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col1', '<f8'), ('col2', '<f8')])

a3 = np.array([('2:6506', 4.6725971801473496e-25, 0.99999999995088695),
       ('2:6601', 2.2452745388799898e-27, 0.99999999995270605),
       ('2:21801', 1.9849650921836601e-31, 0.99999999997999001),], 
      dtype=[('pos', '|S100'), ('col3', '<f8'), ('col4', '<f8')])

# Convert each structured array to a DataFrame.
dfs = [pd.DataFrame.from_records(a) for a in (a1, a2, a3)]

# Outer-merge the frames one at a time on the columns they share.
result = dfs[0]
for df in dfs[1:]:
    cols = list(set(result.columns).intersection(df.columns))
    result = pd.merge(result, df, how='outer', on=cols)

print(result)

yields

       pos          col1  col2          col3  col4
0   2:6506  4.672597e-25     1  4.672597e-25     1
1   2:6601  2.245275e-27     1  2.245275e-27     1
2  2:21801  1.984965e-31     1  1.984965e-31     1
3   3:6506  4.672597e-25     1           NaN   NaN
4   3:6601  2.245275e-27     1           NaN   NaN
5  3:21801  1.984965e-31     1           NaN   NaN

[6 rows x 5 columns]
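When `on` is not given, `pd.merge` joins on the columns common to both frames, which is the same intersection the loop computes by hand. Under that assumption the loop can be folded into a single `functools.reduce` call. This is only a sketch with made-up frames `d1`, `d2`, `d3` standing in for the converted arrays:

```python
from functools import reduce

import pandas as pd

# Illustrative stand-ins for the DataFrames built from a1, a2, a3.
d1 = pd.DataFrame({'pos': ['2:6506', '2:6601'],
                   'col1': [1.0, 2.0], 'col2': [0.5, 0.6]})
d2 = pd.DataFrame({'pos': ['3:6506', '3:6601'],
                   'col1': [3.0, 4.0], 'col2': [0.7, 0.8]})
d3 = pd.DataFrame({'pos': ['2:6506', '2:6601'],
                   'col3': [5.0, 6.0], 'col4': [0.9, 1.0]})

# With no `on` argument, merge uses the columns shared by both sides.
merged = reduce(lambda left, right: pd.merge(left, right, how='outer'),
                [d1, d2, d3])
print(merged)
```

Spelling out `on=cols` as in the loop is still the safer choice when the frames might share a column by accident.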

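Pandas can sometimes be a little slower than a pure NumPy solution. That is usually because Pandas offers a more complete solution that correctly handles special cases, such as NaN (missing values) or duplicate index values, which a simple NumPy solution may not have accounted for.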
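As a small illustration of the duplicate-key case, here is a sketch with invented data showing what an outer merge does when a key repeats on one side, and how the `validate` argument of `pd.merge` can reject such input if it is unexpected:

```python
import pandas as pd

# Illustrative data: '2:6506' appears twice on the left.
left = pd.DataFrame({'pos': ['2:6506', '2:6506', '3:6601'],
                     'col1': [1.0, 2.0, 3.0]})
right = pd.DataFrame({'pos': ['2:6506', '2:21801'],
                      'col3': [10.0, 20.0]})

# Each duplicate on the left is paired with the matching right row;
# keys present on only one side get NaN in the other side's columns.
merged = pd.merge(left, right, how='outer', on='pos')
print(merged)

# If duplicates would indicate bad data, ask merge to check for them.
try:
    pd.merge(left, right, how='outer', on='pos', validate='one_to_one')
    rejected = False
except pd.errors.MergeError:
    rejected = True
print(rejected)
```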

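Note also that a Pandas DataFrame has a .values attribute that returns the underlying data as a NumPy array, and a .to_records method that returns a structured array. As you saw above, there is also a DataFrame.from_records constructor that converts a structured array into a DataFrame. So you can easily move between DataFrames and NumPy arrays whenever you really need to.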
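The round trip can be sketched as follows. The array here is a made-up two-row example, and it uses a unicode 'pos' field rather than the byte strings from the question, purely to keep the output readable:

```python
import numpy as np
import pandas as pd

# Illustrative structured array (assumption: unicode key field).
arr = np.array([('2:6506', 1.0), ('3:6506', 2.0)],
               dtype=[('pos', 'U10'), ('col1', '<f8')])

df = pd.DataFrame.from_records(arr)    # structured array -> DataFrame
rec = df.to_records(index=False)       # DataFrame -> record array
plain = df.values                      # DataFrame -> 2-D NumPy array
back = pd.DataFrame.from_records(rec)  # and back to a DataFrame

print(rec.dtype.names)
print(plain.shape)
```

Passing `index=False` keeps the DataFrame's row index out of the record array, so the field names match the original array.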

So I don't think Pandas puts you at any real speed disadvantage, and the convenience it offers should make your data analysis easier.
