Python - 获取三维数组的“子数组”

1 投票
1 回答
791 浏览
提问于 2025-04-30 13:54

我想从一个三维数组中获取多个子数组。在二维数组的情况下,我可以用一个在Stack上找到的函数来拆分数组:

def blockshaped(arr, nrows, ncols):
    h, w = arr.shape
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))

所以我想把这个方法扩展到三维数组的情况,形成像二维数组那样的块,但在每个第一维的切片中。我尝试用“for循环”,但没有成功……

举个例子:

import numpy as np

#2D case (which works)

test=np.array([[ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 3.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]])

def blockshaped(arr, nrows, ncols): 

    h, w = arr.shape
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))


sub = blockshaped(test, 2,2)

我得到了4个“子数组”:

array([[[ 2.,  1.],
        [ 1.,  1.]],

       [[ 1.,  1.],
        [ 1.,  1.]],

       [[ 3.,  1.],
        [ 1.,  1.]],

       [[ 1.,  1.],
        [ 1.,  1.]]])

但是对于一个三维数组作为输入:

test2=np.array([[[ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 3.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]],

       [[ 5.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.],
        [ 2.,  1.,  1., 1.],
        [ 1.,  1.,  1., 1.]]])       

所以在这里我想要同样的分解,但在这2个“切片”中……

def blockshaped(arr, nrows, ncols): 

    h, w, t = arr.shape 
    return (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1,2)
               .reshape(-1, nrows, ncols))

我尝试用“for循环”,但没有成功:

for i in range(test2.shape[0]):                     
    sub = blockshaped(test[i,:,:], 2, 2)
暂无标签

1 个回答

1

你的 for 循环解决方案可以这样做:

sub = np.array([blockshaped(a, 2, 2) for a in test2])

不过你可以稍微修改一下 blockshaped(),在切片之前和之后对数据进行重新调整:

def blockshaped(arr, nrows, ncols):
    need_reshape = False
    if arr.ndim > 2:
        need_reshape = True
    if need_reshape:
        orig_shape = arr.shape
        arr = arr.reshape(-1, arr.shape[-1])
    h, w = arr.shape
    out = (arr.reshape(h//nrows, nrows, -1, ncols)
               .swapaxes(1, 2)
               .reshape(-1, nrows, ncols))
    if need_reshape:
        new_shape = list(out.shape)
        new_shape[0] //= orig_shape[0]
        out = out.reshape([-1,] + new_shape)
    return out

撰写回答