如何从结构化 numpy 数组中删除列?
想象一下,你有一个结构化的numpy数组,这个数组是从一个csv文件生成的,第一行是字段名称。这个数组的样子是:
dtype([('A', '<f8'), ('B', '<f8'), ('C', '<f8'), ..., ('n','<f8'])
现在,假设你想从这个数组中删除第'i'列。有没有什么简单的方法可以做到这一点呢?
我希望它的工作方式像删除一样:
new_array = np.delete(old_array, 'i')
有没有什么好主意?
3 个回答
0
最简单的解决办法是使用内置的函数。
假设我们有一个叫做 points_array = np.array
的东西。这个 np.array
里面有很多列,其中一列是“类别”。
import numpy.lib.recfunctions as recfc
points_array = recfc.drop_fields(points_array, "classes", usemask=False)
7
我在网上查资料,看到Warren的回答后学到了我需要的东西,所以我忍不住想分享一个更简洁的版本,并且提供一个可以一次性高效删除多个字段的选项:
def rmfield( a, *fieldnames_to_remove ):
return a[ [ name for name in a.dtype.names if name not in fieldnames_to_remove ] ]
示例:
a = rmfield(a, 'foo')
a = rmfield(a, 'foo', 'bar') # remove multiple fields at once
如果我们真的要把代码写得更简短,下面的代码也是等效的:
rmfield=lambda a,*f:a[[n for n in a.dtype.names if n not in f]]
21
这不是一个简单的函数调用,但下面的内容展示了一种删除第 i 个字段的方法:
In [67]: a
Out[67]:
array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)],
dtype=[('A', '<f8'), ('B', '<f8'), ('C', '<f8')])
In [68]: i = 1 # Drop the 'B' field
In [69]: names = list(a.dtype.names)
In [70]: names
Out[70]: ['A', 'B', 'C']
In [71]: new_names = names[:i] + names[i+1:]
In [72]: new_names
Out[72]: ['A', 'C']
In [73]: b = a[new_names]
In [74]: b
Out[74]:
array([(1.0, 3.0), (4.0, 6.0)],
dtype=[('A', '<f8'), ('C', '<f8')])
把它封装成一个函数:
def remove_field_num(a, i):
names = list(a.dtype.names)
new_names = names[:i] + names[i+1:]
b = a[new_names]
return b
可能更自然的做法是根据给定的字段 名称 来删除字段:
def remove_field_name(a, name):
names = list(a.dtype.names)
if name in names:
names.remove(name)
b = a[names]
return b
另外,可以看看 drop_rec_fields
函数,它是 matplotlib 的 mlab
模块的一部分。
更新:请查看我在 如何在不复制的情况下从结构化 numpy 数组中删除一列? 中的回答,了解如何创建结构化数组字段子集的视图,而不需要复制整个数组。