在3D numpy阵列的小切片上高效地使用1D pyfftw

2024-05-14 03:19:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个大小为10000x512x512的三维数据立方体。我想沿着dim[0]反复解析一个向量窗口(比如6),并高效地生成fourier变换。我想我正在做一个数组拷贝到pyfftw包中,这给了我巨大的开销。我现在正在看文档,因为我认为有一个选项我需要设置,但我可以使用一些额外的语法帮助。你知道吗

这段代码最初是由另一个numpy.fft.rfft格式用麻木加速。但是这个实现在我的工作站上不起作用,所以我重新编写了所有内容,并选择使用pyfftw。你知道吗

import numpy as np
import pyfftw as ftw
from tkinter import simpledialog
from math import ceil
import multiprocessing

ftw.config.NUM_THREADS = multiprocessing.cpu_count()
ftw.interfaces.cache.enable()

def runme():
    # normally I would load a file, but for Stack Overflow, I'm just going to generate a 3D data cube so I'll delete references to the binary saving/loading functions:
    # load the file
    dataChunk = np.random.random((1000,512,512))
    numFrames = dataChunk.shape[0]
    # select the window size
    windowSize = int(simpledialog.askstring('Window Size',
        'How many frames to demodulate a single time point?'))
    numChannels = windowSize//2+1
    # create fftw arrays
    ftwIn = ftw.empty_aligned(windowSize, dtype='complex128')
    ftwOut = ftw.empty_aligned(windowSize, dtype='complex128')
    fftObject = ftw.FFTW(ftwIn,ftwOut)
    # perform DFT on the data chunk
    demodFrames = dataChunk.shape[0]//windowSize
    channelChunks = np.zeros([numChannels,demodFrames,
        dataChunk.shape[1],dataChunk.shape[2]])
    channelChunks = getDFT(dataChunk,channelChunks,
        ftwIn,ftwOut,fftObject,windowSize,numChannels)
    return channelChunks          

def getDFT(data,channelOut,ftwIn,ftwOut,fftObject,
        windowSize,numChannels):
    frameLen = data.shape[0]
    demodFrames = frameLen//windowSize
    for yy in range(data.shape[1]):
        for xx in range(data.shape[2]):
            index = 0
            for i in range(0,frameLen-windowSize+1,windowSize):
                ftwIn[:] = data[i:i+windowSize,yy,xx]
                fftObject() 
                channelOut[:,index,yy,xx] = 2*np.abs(ftwOut[:numChannels])/windowSize
                index+=1
    return channelOut

if __name__ == '__main__':
    runme()

我得到了一个4D数组;变量channelChunks。我将每个通道保存到一个二进制文件中(上面的代码中没有包含,但是保存部分可以正常工作)。你知道吗

这个过程是为了我们的解调项目,4D数据立方体通道块然后被解析成eval(numChannel)3D数据立方体(电影),从中我们能够根据我们的实验设置按颜色分离电影。我希望我能绕过写一个C++函数,通过Pyfftw调用矩阵上的FFT。你知道吗

实际上,我在给定的索引1和2轴上沿着dataChunk的0轴取windowSize=6个元素,并执行1D FFT。我需要在整个3D数据块中执行此操作,以生成解调后的电影。谢谢。你知道吗


Tags: the数据importfordatanpftwshape
1条回答
网友
1楼 · 发布于 2024-05-14 03:19:49

pyfftw可以自动生成FFTW advanced plans 可按以下方式修改代码:

  • 实到复变换可以用来代替复到复变换。 使用pyfftw,它通常会写入:

    ftwIn = ftw.empty_aligned(windowSize, dtype='float64')
    ftwOut = ftw.empty_aligned(windowSize//2+1, dtype='complex128')
    fftObject = ftw.FFTW(ftwIn,ftwOut)
    
  • 在FFTW规划器中添加几个标志。例如,FFTW_MEASURE将对不同的算法计时并选择最佳算法。FFTW_DESTROY_INPUT表示可以修改输入数组:可以使用一些实现技巧。你知道吗

    fftObject = ftw.FFTW(ftwIn,ftwOut, flags=('FFTW_MEASURE','FFTW_DESTROY_INPUT',))
    
  • 限制分区数。除法比乘法花费更多。你知道吗

    scale=1.0/windowSize
    for ...
        for ...
            2*np.abs(ftwOut[:,:,:])*scale  #instead of /windowSize
    
  • 通过pyfftw使用FFTW advanced plan避免多个for循环。

    nbwindow=numFrames//windowSize
    # create fftw arrays
    ftwIn = ftw.empty_aligned((nbwindow,windowSize,dataChunk.shape[2]), dtype='float64')
    ftwOut = ftw.empty_aligned((nbwindow,windowSize//2+1,dataChunk.shape[2]), dtype='complex128')
    fftObject = ftw.FFTW(ftwIn,ftwOut, axes=(1,), flags=('FFTW_MEASURE','FFTW_DESTROY_INPUT',))
    
    ...
    for yy in range(data.shape[1]):
        ftwIn[:] = np.reshape(data[0:nbwindow*windowSize,yy,:],(nbwindow,windowSize,data.shape[2]),order='C')
        fftObject()
        channelOut[:,:,yy,:]=np.transpose(2*np.abs(ftwOut[:,:,:])*scale, (1,0,2))
    

这是修改后的代码。我还将帧数减少到100,设置随机生成器的种子,以检查结果是否未被修改和注释tkinter。窗口的大小可以设置为2的幂次方,也可以设置为2、3、5或7的乘积,这样就可以有效地应用Cooley-Tuckey算法。避免大素数。你知道吗

import numpy as np
import pyfftw as ftw
#from tkinter import simpledialog
from math import ceil
import multiprocessing
import time


ftw.config.NUM_THREADS = multiprocessing.cpu_count()
ftw.interfaces.cache.enable()
ftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'

def runme():
    # normally I would load a file, but for Stack Overflow, I'm just going to generate a 3D data cube so I'll delete references to the binary saving/loading functions:
    # load the file
    np.random.seed(seed=42)
    dataChunk = np.random.random((100,512,512))
    numFrames = dataChunk.shape[0]
    # select the window size
    #windowSize = int(simpledialog.askstring('Window Size',
    #    'How many frames to demodulate a single time point?'))
    windowSize=32
    numChannels = windowSize//2+1

    nbwindow=numFrames//windowSize
    # create fftw arrays
    ftwIn = ftw.empty_aligned((nbwindow,windowSize,dataChunk.shape[2]), dtype='float64')
    ftwOut = ftw.empty_aligned((nbwindow,windowSize//2+1,dataChunk.shape[2]), dtype='complex128')

    #ftwIn = ftw.empty_aligned(windowSize, dtype='complex128')
    #ftwOut = ftw.empty_aligned(windowSize, dtype='complex128')
    fftObject = ftw.FFTW(ftwIn,ftwOut, axes=(1,), flags=('FFTW_MEASURE','FFTW_DESTROY_INPUT',))
    # perform DFT on the data chunk
    demodFrames = dataChunk.shape[0]//windowSize
    channelChunks = np.zeros([numChannels,demodFrames,
        dataChunk.shape[1],dataChunk.shape[2]])
    channelChunks = getDFT(dataChunk,channelChunks,
        ftwIn,ftwOut,fftObject,windowSize,numChannels)
    return channelChunks          

def getDFT(data,channelOut,ftwIn,ftwOut,fftObject,
        windowSize,numChannels):
    frameLen = data.shape[0]
    demodFrames = frameLen//windowSize
    printed=0
    nbwindow=data.shape[0]//windowSize
    scale=1.0/windowSize
    for yy in range(data.shape[1]):
        #for xx in range(data.shape[2]):
            index = 0

            ftwIn[:] = np.reshape(data[0:nbwindow*windowSize,yy,:],(nbwindow,windowSize,data.shape[2]),order='C')
            fftObject()
            channelOut[:,:,yy,:]=np.transpose(2*np.abs(ftwOut[:,:,:])*scale, (1,0,2))
            #for i in range(nbwindow):
                #channelOut[:,i,yy,xx] = 2*np.abs(ftwOut[i,:])*scale

            if printed==0:
                      for j in range(channelOut.shape[0]):
                          print j,channelOut[j,0,yy,0]
                      printed=1

    return channelOut

if __name__ == '__main__':
    seconds=time.time()
    runme()
    print "time: ", time.time()-seconds

让我们知道它有多快你的计算!我的电脑从24秒变成了不到2秒。。。你知道吗

相关问题 更多 >