条件下numpy数组的串联

2024-04-23 10:43:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我想知道是否有一种简单的方法,根据(len a+b)中的一个布尔值沿着一个轴连接两个不同len(a和b)的数组

我已经可以使用python列表理解(根据布尔值)或创建多个新数组(2)来创建这个数组。你知道吗

简单示例:

import numpy as np

a = np.arange(3)
b = np.arange(3, 8)
a_b = np.arange(8) < 3

def np_or_not(a_b, a, b, axis=-1):
    axis = axis // len(a_b)
    aligned_a = np.take(a, np.cumsum(a_b) - 1, axis=axis)
    aligned_b = np.take(b, np.cumsum(~a_b) - 1, axis=axis)
    return np.where(a_b, aligned_a, aligned_b)

assert (np_or_not(a_b, a, b) == np.arange(8)).all()

N-D案例:

初始化

axis = 2
a = np.random.rand(10, 20, 4)
b = np.random.rand(10, 20, 9)
a_b = np.array([True] * a.shape[axis] + [False] * b.shape[axis])
np.random.shuffle(a_b)

支票(格式)

assert len(a.shape) == len(b.shape)
assert all(a.shape[ax] == b.shape[ax] for ax in range(len(a.shape)) if ax != axis)
assert len(a_b) == (a.shape[axis] + b.shape[axis])
assert sum(a_b) == a.shape[axis]
assert sum(~a_b) == b.shape[axis]

检查结果

result = np_or_not(a_b, a, b, axis=-1)

assert result.shape == tuple(l if ax != axis else len(a_b) for ax, l in enumerate(a.shape))

a_b_indexes = [0, 0]
for index, truth in enumerate(a_b):
    assert (result[..., index] == (a if truth else b)[..., a_b_indexes[1 - truth]]).all(), index
    a_b_indexes[1 - truth] += 1

编辑:感谢您的回复,我用concatenate替换了stacking,并提供了一个N-D案例示例。你知道吗


Tags: orforlennpnotrandom数组assert