沿动态指定轴切片numpy数组

59 投票
7 回答
31000 浏览
提问于 2025-04-18 10:58

我想要在一个特定的方向上动态地切割一个numpy数组。给定这个:

axis = 2
start = 5
end = 10

我想要得到和这个一样的结果:

# m is some matrix
m[:,:,5:10]

使用类似这样的东西:

slc = tuple(:,) * len(m.shape)
slc[axis] = slice(start,end)
m[slc]

但是:的值不能放在一个元组里,所以我不知道该怎么构建这个切割。

7 个回答

8

有一种优雅的方法可以访问数组 x 的任意轴 n:使用 numpy.moveaxis¹ 将你感兴趣的轴移动到最前面。

x_move = np.moveaxis(x, n, 0)  # move n-th axis to front
x_move[start:end]              # access n-th axis

需要注意的是,你可能还需要对其他数组使用 moveaxis,这样才能保持和 x_move[start:end] 输出的一致性。因为数组 x_move 只是一个视图,所以你对它前面的轴所做的任何更改,都会影响到 xn 轴上的内容(也就是说,你可以对 x_move 进行读写操作)。


1) 你也可以使用 swapaxes,这样就不用担心 n0 的顺序,和 moveaxis(x, n, 0) 不同。我更喜欢 moveaxis,因为它只改变与 n 相关的顺序。

17

这可能有点晚了,但用Numpy的默认方法来做这个事情是numpy.take。不过,这个方法总是会复制数据(因为它支持复杂的索引,所以总是认为这是可能的)。为了避免这种情况(在很多情况下,你可能想要的是数据的视图,而不是复制),可以使用之前提到的slice(None)选项,最好把它封装成一个好用的函数:

def simple_slice(arr, inds, axis):
    # this does the same as np.take() except only supports simple slicing, not
    # advanced indexing, and thus is much faster
    sl = [slice(None)] * arr.ndim
    sl[axis] = inds
    return arr[tuple(sl)]
18

虽然我来得有点晚,但我有一个替代的切片函数,它的性能比其他答案中的稍微好一点:

def array_slice(a, axis, start, end, step=1):
    return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]

这里有一段代码用来测试每个答案。每个版本都标注了发布答案的用户名字:

import numpy as np
from timeit import timeit

def answer_dms(a, axis, start, end, step=1):
    slc = [slice(None)] * len(a.shape)
    slc[axis] = slice(start, end, step)
    return a[slc]

def answer_smiglo(a, axis, start, end, step=1):
    return a.take(indices=range(start, end, step), axis=axis)

def answer_eelkespaak(a, axis, start, end, step=1):
    sl = [slice(None)] * m.ndim
    sl[axis] = slice(start, end, step)
    return a[tuple(sl)]

def answer_clemisch(a, axis, start, end, step=1):
    a = np.moveaxis(a, axis, 0)
    a = a[start:end:step]
    return np.moveaxis(a, 0, axis)

def answer_leland(a, axis, start, end, step=1):
    return a[(slice(None),) * (axis % a.ndim) + (slice(start, end, step),)]

if __name__ == '__main__':
    m = np.arange(2*3*5).reshape((2,3,5))
    axis, start, end = 2, 1, 3
    target = m[:, :, 1:3]
    for answer in (answer_dms, answer_smiglo, answer_eelkespaak,
                   answer_clemisch, answer_leland):
        print(answer.__name__)
        m_copy = m.copy()
        m_slice = answer(m_copy, axis, start, end)
        c = np.allclose(target, m_slice)
        print('correct: %s' %c)
        t = timeit('answer(m, axis, start, end)',
                   setup='from __main__ import answer, m, axis, start, end')
        print('time:    %s' %t)
        try:
            m_slice[0,0,0] = 42
        except:
            print('method:  view_only')
        finally:
            if np.allclose(m, m_copy):
                print('method:  copy')
            else:
                print('method:  in_place')
        print('')

以下是结果:

answer_dms

Warning (from warnings module):
  File "C:\Users\leland.hepworth\test_dynamic_slicing.py", line 7
    return a[slc]
FutureWarning: Using a non-tuple sequence for multidimensional indexing is 
deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be 
interpreted as an array index, `arr[np.array(seq)]`, which will result either in an 
error or a different result.
correct: True
time:    2.2048302
method:  in_place

answer_smiglo
correct: True
time:    5.9013344
method:  copy

answer_eelkespaak
correct: True
time:    1.1219435999999998
method:  in_place

answer_clemisch
correct: True
time:    13.707583699999999
method:  in_place

answer_leland
correct: True
time:    0.9781496999999995
method:  in_place
  • DSM的答案在评论中提供了一些改进建议。
  • EelkeSpaak的答案应用了这些改进,避免了警告,并且速度更快。
  • Śmigło的答案使用了np.take,结果更差,虽然它不是只读的,但确实会创建一个副本。
  • clemisch的答案使用了np.moveaxis,完成时间最长,但意外的是,它会引用之前数组的内存位置。
  • 我的答案省去了中间切片列表的需要。当切片轴靠近开头时,它还使用了更短的切片索引。这使得结果最快,随着轴接近0还有额外的改进。

我还为每个版本添加了一个step参数,以防你需要这个功能。

68

因为没有说明得很清楚(我也在找这个):

与以下内容等价的是:

a = my_array[:, :, :, 8]
b = my_array[:, :, :, 2:7]

是:

a = my_array.take(indices=8, axis=3)
b = my_array.take(indices=range(2, 7), axis=3)
42

我觉得有一种方法可以用 slice(None) 来实现:

>>> m = np.arange(2*3*5).reshape((2,3,5))
>>> axis, start, end = 2, 1, 3
>>> target = m[:, :, 1:3]
>>> target
array([[[ 1,  2],
        [ 6,  7],
        [11, 12]],

       [[16, 17],
        [21, 22],
        [26, 27]]])
>>> slc = [slice(None)] * len(m.shape)
>>> slc[axis] = slice(start, end)
>>> np.allclose(m[slc], target)
True

我隐约记得我之前用过一个函数来做这个,但现在找不到了……

撰写回答