在numpy数组中删除特定元素索引
我该怎么做才能从一个numpy数组的每一行中去掉除了几个特定的元素索引之外的所有元素呢?比如,下面这个例子中,想要删除每一行的第二个和第四个元素。
原始数组是:[[1,4,5,6,1], [1,4,5,1,3], [1,4,5,1,0], [1,6,2,6,9], [8,4,4,6,7]]
处理后的数组是:[[1,5,1], [1,5,3], [1,5,0], [1,2,9], [8,4,7]]
我尝试过用np.rot90(np.fliplr())来翻转数组,然后把每一行的元素添加到一个新数组中,最后再翻转回来,但这样做有点麻烦,因为我的数组非常大。
3 个回答
1
对于大数据集,你可以使用这种方法:
import numpy as np
# Your initial array
before = np.array([[1, 4, 5, 6, 1], [1, 4, 5, 1, 3], [1, 4, 5, 1, 0], [1, 6, 2, 6, 9], [8, 4, 4, 6, 7]])
# Indices of the elements you want to keep
indices_to_keep = [0, 2, 4]
# Use numpy's advanced indexing to select the desired elements
after = before[:, indices_to_keep]
print(after)
这个方法利用了NumPy的高级索引功能,而不是多次翻转数组,这样效率更高,特别是对于大数组来说。
1
你可以选择你想保留的那一列:
after = before[:, [0, 2, 4]]
[[1 5 1]
[1 5 3]
[1 5 0]
[1 2 9]
[8 4 7]]
1
在这个特定的情况下,你可以使用切片:
after = before[:,::2]
注意,这样做会创建一个新的 np.ndarray
对象,它是原始数据的一种视图。
一个更通用的方法可能是这样的:
after = np.delete(before, [1, 3], axis=1)
这样做就不应该是原始数据的视图了。