按索引过滤并在numpy中变平,比如tf.序列

2024-05-12 22:11:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我想用一个索引过滤我的数组2D,然后只用过滤器中的值来调整这个数组。这差不多是什么tf.sequence_掩码可以,但我需要这个在纽比或其他光图书馆。在

谢谢!在

警察局: 这是一个例子:

array_2d = [[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]] # this is a numpy array
array_len = [6,5,3]
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]

Tags: numpy过滤器outputlen图书馆istf数组
2条回答

这是一种使用布尔蒙版并将其应用于展平的array_2d的方法

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

# Create a boolean mask
mask = np.zeros((array_2d.shape), dtype=bool)

# Change to True for elements to be kept
for i, j in enumerate(array_len):
        mask[i][0:j] = True

expected_output = array_2d.flatten()[mask.flatten()]

输出

^{pr2}$

下面是一个vectorized解决方案,使用布尔掩码索引array_2d

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])

详细信息

创建掩码时将^{}放在与array_2d形状相同的^{}上,并执行行比较以查看哪些元素大于array_len。在

因此,第一步是创建以下ndarray

^{pr2}$

并与array_len执行行比较:

~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True, False],
       [ True,  True,  True, False, False, False]])

然后您只需使用以下内容筛选数组:

array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])

相关问题 更多 >