numpy/scipy,循环遍历子数组
最近我在处理8x8的图像数据块,通常的方法是用嵌套的for循环来提取这些数据块,比如下面这样:
for y in xrange(0,height,8):
for x in xrange(0,width,8):
d = image_data[y:y+8,x:x+8]
# further processing on the 8x8-block
我在想,是否有办法把这个操作变得更高效,或者用numpy/scipy中的其他方法来实现呢?比如说,使用某种迭代器?
一个简单的例子:
#!/usr/bin/env python
import sys
import numpy as np
from scipy.fftpack import dct, idct
import scipy.misc
import matplotlib.pyplot as plt
def dctdemo(coeffs=1):
unzig = np.array([
0, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63])
lena = scipy.misc.lena()
width, height = lena.shape
# reconstructed
rec = np.zeros(lena.shape, dtype=np.int64)
# Can this part be vectorized?
for y in xrange(0,height,8):
for x in xrange(0,width,8):
d = lena[y:y+8,x:x+8].astype(np.float)
D = dct(dct(d.T, norm='ortho').T, norm='ortho').reshape(64)
Q = np.zeros(64, dtype=np.float)
Q[unzig[:coeffs]] = D[unzig[:coeffs]]
Q = Q.reshape([8,8])
q = np.round(idct(idct(Q.T, norm='ortho').T, norm='ortho'))
rec[y:y+8,x:x+8] = q.astype(np.int64)
plt.imshow(rec, cmap='gray')
plt.show()
if __name__ == '__main__':
try:
c = int(sys.argv[1])
except ValueError:
sys.exit()
else:
if 1 <= int(sys.argv[1]) <= 64:
dctdemo(int(sys.argv[1]))
脚注:
2 个回答
3
在scikit-learn的特征提取功能里,有一个叫做 extract_patches
的函数。你需要指定一个 patch_size
(补丁大小)和一个 extraction_step
(提取步长)。这个函数的结果会把你的图像分成多个小块,这些小块可能会有重叠。最终得到的数组是四维的,前两个维度表示补丁,后两个维度表示补丁里的像素。你可以试试这个
from sklearn.feature_extraction.image import extract_patches
patches = extract_patches(image_data, patch_size=(8, 8), extraction_step=(4, 4))
这样会得到大小为(8, 8)的补丁,并且这些补丁会有一半的重叠。
值得注意的是,到目前为止,这个过程没有额外占用内存,因为它是通过一种叫做步幅技巧的方式实现的。如果你想强制复制,可以通过重塑数组来实现
patches = patches.reshape(-1, 8, 8)
这样基本上就会得到一个补丁的列表。
5
在Scikit Image中,有一个叫做 view_as_windows
的函数可以用来实现这个功能。
很遗憾,我得等下次再完成这个回答,不过你可以用下面的方式获取可以传递给 dct
的窗口:
from skimage.util import view_as_windows
# your code...
d = view_as_windows(lena.astype(np.float), (8, 8)).reshape(-1, 8, 8)
dct(d, axis=0)