numpy/scipy,循环遍历子数组

3 投票
2 回答
1866 浏览
提问于 2025-04-18 05:44

最近我在处理8x8的图像数据块,通常的方法是用嵌套的for循环来提取这些数据块,比如下面这样:

for y in xrange(0,height,8):
    for x in xrange(0,width,8):
        d = image_data[y:y+8,x:x+8]
        # further processing on the 8x8-block

我在想,是否有办法把这个操作变得更高效,或者用numpy/scipy中的其他方法来实现呢?比如说,使用某种迭代器?

一个简单的例子:

#!/usr/bin/env python

import sys
import numpy as np
from scipy.fftpack import dct, idct
import scipy.misc
import matplotlib.pyplot as plt

def dctdemo(coeffs=1):
    unzig = np.array([
         0,  1,  8, 16,  9,  2,  3, 10,
        17, 24, 32, 25, 18, 11,  4, 5,
        12, 19, 26, 33, 40, 48, 41, 34,
        27, 20, 13,  6,  7, 14, 21, 28,
        35, 42, 49, 56, 57, 50, 43, 36,
        29, 22, 15, 23, 30, 37, 44, 51,
        58, 59, 52, 45, 38, 31, 39, 46,
        53, 60, 61, 54, 47, 55, 62, 63])

    lena = scipy.misc.lena()
    width, height = lena.shape

    # reconstructed
    rec = np.zeros(lena.shape, dtype=np.int64)

    # Can this part be vectorized?
    for y in xrange(0,height,8):
        for x in xrange(0,width,8):
            d = lena[y:y+8,x:x+8].astype(np.float)
            D = dct(dct(d.T, norm='ortho').T, norm='ortho').reshape(64)
            Q = np.zeros(64, dtype=np.float)
            Q[unzig[:coeffs]] = D[unzig[:coeffs]]
            Q = Q.reshape([8,8])
            q = np.round(idct(idct(Q.T, norm='ortho').T, norm='ortho'))
            rec[y:y+8,x:x+8] = q.astype(np.int64)

    plt.imshow(rec, cmap='gray')
    plt.show()

if __name__ == '__main__':
    try:
        c = int(sys.argv[1])
    except ValueError:
        sys.exit()
    else:
        if 1 <= int(sys.argv[1]) <= 64:
            dctdemo(int(sys.argv[1]))

脚注:

  1. 实际应用: https://github.com/figgis/dctdemo

2 个回答

3

在scikit-learn的特征提取功能里,有一个叫做 extract_patches 的函数。你需要指定一个 patch_size(补丁大小)和一个 extraction_step(提取步长)。这个函数的结果会把你的图像分成多个小块,这些小块可能会有重叠。最终得到的数组是四维的,前两个维度表示补丁,后两个维度表示补丁里的像素。你可以试试这个

from sklearn.feature_extraction.image import extract_patches
patches = extract_patches(image_data, patch_size=(8, 8), extraction_step=(4, 4))

这样会得到大小为(8, 8)的补丁,并且这些补丁会有一半的重叠。

值得注意的是,到目前为止,这个过程没有额外占用内存,因为它是通过一种叫做步幅技巧的方式实现的。如果你想强制复制,可以通过重塑数组来实现

patches = patches.reshape(-1, 8, 8)

这样基本上就会得到一个补丁的列表。

5

在Scikit Image中,有一个叫做 view_as_windows 的函数可以用来实现这个功能。

很遗憾,我得等下次再完成这个回答,不过你可以用下面的方式获取可以传递给 dct 的窗口:

from skimage.util import view_as_windows
# your code...
d = view_as_windows(lena.astype(np.float), (8, 8)).reshape(-1, 8, 8)
dct(d, axis=0)

撰写回答