加速遍历Numpy数组

7 投票
3 回答
4796 浏览
提问于 2025-04-16 21:17

我正在用Numpy进行图像处理,具体来说是进行一个运行标准差拉伸。这个过程会读取一定数量的列,计算标准差,然后进行百分比线性拉伸。接着,它会继续处理下一组列,重复相同的操作。输入的图像是一个1GB、32位、单波段的栅格图像,处理起来非常耗时(几个小时)。下面是代码。

我意识到我有3个嵌套的for循环,这可能是导致处理速度慢的原因。如果我把图像分成“块”来处理,比如加载一个[500,500]的数组,处理时间会短很多。不幸的是,由于相机的错误,我必须以非常长的条带(52,000 x 4)(y,x)来处理,以避免出现条纹。

如果有任何加速处理的建议,我将非常感激:

def box(dataset, outdataset, sampleSize, n):

    quiet = 0
    sample = sampleSize
    #iterate over all of the bands
    for j in xrange(1, dataset.RasterCount + 1): #1 based counter

        band = dataset.GetRasterBand(j)
        NDV = band.GetNoDataValue()

        print "Processing band: " + str(j)       

        #define the interval at which blocks are created
        intervalY = int(band.YSize/1)    
        intervalX = int(band.XSize/2000) #to be changed to sampleSize when working

        #iterate through the rows
        scanBlockCounter = 0

        for i in xrange(0,band.YSize,intervalY):

            #If the next i is going to fail due to the edge of the image/array
            if i + (intervalY*2) < band.YSize:
                numberRows = intervalY
            else:
                numberRows = band.YSize - i

            for h in xrange(0,band.XSize, intervalX):

                if h + (intervalX*2) < band.XSize:
                    numberColumns = intervalX
                else:
                    numberColumns = band.XSize - h

                scanBlock = band.ReadAsArray(h,i,numberColumns, numberRows).astype(numpy.float)

                standardDeviation = numpy.std(scanBlock)
                mean = numpy.mean(scanBlock)

                newMin = mean - (standardDeviation * n)
                newMax = mean + (standardDeviation * n)

                outputBlock = ((scanBlock - newMin)/(newMax-newMin))*255
                outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset


                scanBlockCounter = scanBlockCounter + 1
                #print str(scanBlockCounter) + ": " + str(scanBlock.shape) + str(h)+ ", " + str(intervalX)
                if numberColumns == band.XSize - h:
                    break

                #update progress line
                if not quiet:
                    gdal.TermProgress_nocb( (float(h+1) / band.YSize) )

这是一个更新: 在没有使用profile模块的情况下,因为我不想把小段代码封装成函数,所以我用了打印和退出语句来大致了解哪些行耗时最多。幸运的是(我知道我真的很幸运),有一行代码拖慢了整个过程。

    outRaster = outdataset.GetRasterBand(j).WriteArray(outputBlock,h,i)#array, xOffset, yOffset

看起来GDAL在打开输出文件和写出数组时效率很低。考虑到这一点,我决定把我修改后的数组“outBlock”添加到一个Python列表中,然后分块写出。这里是我修改的部分:

outputBlock只是被修改了...

         #Add the array to a list (tuple)
            outputArrayList.append(outputBlock)

            #Check the interval counter and if it is "time" write out the array
            if len(outputArrayList) >= (intervalX * writeSize) or finisher == 1:

                #Convert the tuple to a numpy array.  Here we horizontally stack the tuple of arrays.
                stacked = numpy.hstack(outputArrayList)

                #Write out the array
                outRaster = outdataset.GetRasterBand(j).WriteArray(stacked,xOffset,i)#array, xOffset, yOffset
                xOffset = xOffset + (intervalX*(intervalX * writeSize))

                #Cleanup to conserve memory
                outputArrayList = list()
                stacked = None
                finisher=0

Finisher只是一个处理边缘的标志。花了一些时间才搞清楚如何从列表中构建数组。在这个过程中,使用numpy.array会创建一个3维数组(有人能解释一下为什么吗?),而写入数组需要的是2维数组。现在总的处理时间从不到2分钟到5分钟不等。有人知道为什么会有这样的时间差吗?

非常感谢所有发帖的人!下一步是深入学习Numpy,了解向量化以进一步优化。

3 个回答

2

在不完全理解你在做什么的情况下,我注意到你没有使用任何 numpy 切片数组广播。这两种方法都可以让你的代码运行得更快,或者至少让代码更容易理解。如果这些和你的问题无关,我很抱歉。

7

加快对 numpy 数据操作的一种方法是使用 vectorize。简单来说,vectorize 会把一个函数 f 转换成一个新函数 g,这个新函数会对一个数组 a 进行操作。你可以像这样调用 gg(a)

>>> sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> sqrt_vec(numpy.arange(10))
array([ 0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ,
        2.23606798,  2.44948974,  2.64575131,  2.82842712,  3.        ])

我不能确定这是否会对你有帮助,因为我不知道你正在处理的数据是什么,但也许你可以把上面的内容改写成一组可以被 vectorized 的函数。也许在这种情况下,你可以对 ReadAsArray(h,i,numberColumns, numberRows) 的索引数组进行向量化。这里有一个潜在好处的例子:

>>> print setup1
import numpy
sqrt_vec = numpy.vectorize(lambda x: x ** 0.5)
>>> print setup2
import numpy
def sqrt_vec(a):
    r = numpy.zeros(len(a))
    for i in xrange(len(a)):
        r[i] = a[i] ** 0.5
    return r
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup1, number=1)
0.30318188667297363
>>> timeit.timeit(stmt='a = sqrt_vec(numpy.arange(1000000))', setup=setup2, number=1)
4.5400981903076172

速度提高了15倍!另外,注意到 numpy 的切片功能可以优雅地处理 ndarray 的边界:

>>> a = numpy.arange(25).reshape((5, 5))
>>> a[3:7, 3:7]
array([[18, 19],
       [23, 24]])

所以如果你能把你的 ReadAsArray 数据放到一个 ndarray 中,你就不需要做任何边界检查的麻烦事了。


关于你问的重塑数据的问题——重塑并不会从根本上改变数据。它只是改变了 numpy 索引数据时的“步幅”。当你调用 reshape 方法时,返回的值是数据的新视图;数据本身没有被复制或改变,旧的视图和旧的步幅信息也没有被改变。

>>> a = numpy.arange(25)
>>> b = a.reshape((5, 5))
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24])
>>> b
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])
>>> a[5]
5
>>> b[1][0]
5
>>> a[5] = 4792
>>> b[1][0]
4792
>>> a.strides
(8,)
>>> b.strides
(40, 8)
5

按照要求回答。

如果你的程序在输入输出(IO)方面比较慢,那你应该把读取和写入的数据分成小块来处理。比如说,可以先把大约500MB的数据放到一个数组里,处理完这些数据后再写出去,然后再读取下一块大约500MB的数据。记得重复使用那个数组,这样可以提高效率。

撰写回答