如何从结构化 numpy 数组中删除列?

17 投票
3 回答
10498 浏览
提问于 2025-04-17 20:00

想象一下,你有一个结构化的numpy数组,这个数组是从一个csv文件生成的,第一行是字段名称。这个数组的样子是:

dtype([('A', '<f8'), ('B', '<f8'), ('C', '<f8'), ..., ('n','<f8'])

现在,假设你想从这个数组中删除第'i'列。有没有什么简单的方法可以做到这一点呢?

我希望它的工作方式像删除一样:

new_array = np.delete(old_array, 'i')

有没有什么好主意?

3 个回答

0

最简单的解决办法是使用内置的函数。

假设我们有一个叫做 points_array = np.array 的东西。这个 np.array 里面有很多列,其中一列是“类别”。

import numpy.lib.recfunctions as recfc

points_array = recfc.drop_fields(points_array, "classes", usemask=False)
7

我在网上查资料,看到Warren的回答后学到了我需要的东西,所以我忍不住想分享一个更简洁的版本,并且提供一个可以一次性高效删除多个字段的选项:

def rmfield( a, *fieldnames_to_remove ):
    return a[ [ name for name in a.dtype.names if name not in fieldnames_to_remove ] ]

示例:

a = rmfield(a, 'foo')
a = rmfield(a, 'foo', 'bar')  # remove multiple fields at once

如果我们真的要把代码写得更简短,下面的代码也是等效的:

rmfield=lambda a,*f:a[[n for n in a.dtype.names if n not in f]]
21

这不是一个简单的函数调用,但下面的内容展示了一种删除第 i 个字段的方法:

In [67]: a
Out[67]: 
array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], 
      dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])

In [68]: i = 1   # Drop the 'B' field

In [69]: names = list(a.dtype.names)

In [70]: names
Out[70]: ['A', 'B', 'C']

In [71]: new_names = names[:i] + names[i+1:]

In [72]: new_names
Out[72]: ['A', 'C']

In [73]: b = a[new_names]

In [74]: b
Out[74]: 
array([(1.0, 3.0), (4.0, 6.0)], 
      dtype=[('A', '<f8'), ('C', '<f8')])

把它封装成一个函数:

def remove_field_num(a, i):
    names = list(a.dtype.names)
    new_names = names[:i] + names[i+1:]
    b = a[new_names]
    return b

可能更自然的做法是根据给定的字段 名称 来删除字段:

def remove_field_name(a, name):
    names = list(a.dtype.names)
    if name in names:
        names.remove(name)
    b = a[names]
    return b

另外,可以看看 drop_rec_fields 函数,它是 matplotlib 的 mlab 模块的一部分。


更新:请查看我在 如何在不复制的情况下从结构化 numpy 数组中删除一列? 中的回答,了解如何创建结构化数组字段子集的视图,而不需要复制整个数组。

撰写回答