利用Numba加速滚动均方根偏差的计算

2024-05-26 19:54:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图加快滚动均方根偏差的计算速度,我一直在使用一个半优化的代码。这是我的密码:


原始均方根偏差(当n很小时,此函数运行速度很快,但会受到较大n值的影响)

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def rt_mean_sq_dev(array1, array2, n):
    result = np.empty(array1.shape[0])
    result[:n-1] = np.nan

    for i in prange(n-1,array1.shape[0]):
        result[i] = np.sqrt(np.sum(np.square(array1[i+1-n:i+1] - array2[i]))/n)
    return result

半优化均方根偏差(无论n的大小,运行时不变)

@njit  
def rt_mean_sq_dev4(array1, array2, n):
    msd_temp = np.empty(array1.shape[0])

    K = array2[n-1]

    rs_x= array1[0] - K
    rs_xsq = rs_x *rs_x

    msd_temp[0] = np.nan

    for i in range(1,n):#first part of the cumsum for x adn xsq
        rs_x += array1[i] - K
        rs_xsq += np.square(array1[i] - K)
        msd_temp[i] = np.nan

    y_i = array2[n-1] - K
    msd_temp[n-1] = np.sqrt(max(y_i*y_i + (rs_xsq - 2*y_i*rs_x)/n, 0))

    for i in range(n, array1.shape[0]):
        rs_x = array1[i] - array1[i-n]+ rs_x
        rs_xsq = np.square(array1[i] - K) - np.square(array1[i-n] - K) + rs_xsq
        y_i = array2[i] - K

        msd_temp[i] = np.sqrt(max(y_i*y_i + (rs_xsq - 2*y_i*rs_x)/n, 0))

    return msd_temp 

if __name__ == '__main__':

    np.random.seed(0)
    data_size = 200000

    data_c = np.random.uniform(0,1000, size = data_size)+29000
    data_d = data_c  + np.random.uniform(0,1, size = data_size)


    tail = np.tile(30000,200)
    data_c = np.hstack((data_c, tail))
    data_d = np.hstack((data_d, tail))

    N = 3
    test4 = rt_mean_sq_dev4(data_c, data_d,  N )

n = 7000时,原始函数每个循环占用320ms,而rt_mean_sq_dev4仅占用2.6ms;然而,当n = 3时,原始函数占用1.3ms,此函数保持不变,大约2.5ms。所以总的来说rt_mean_sq_dev4更快,更受欢迎。你知道吗

rt_mean_sq_dev4是迄今为止我在这个问题上的最佳尝试,在for循环中使用了某种滚动累积结构。我曾尝试创建numpy数组来“存储”第二个循环中以前计算的一些平方值,以节省之前的一些计算时间,但令我惊讶的是,这种更改实际上使函数运行速度变慢。你知道吗

我不禁觉得,必须有一些更好的方法来实现我还不知道的计算;或者有一些小而重要的技术来进一步加速我还不知道的代码;或者有一些scipy函数正是我想做的,但速度更快,我只是不知道,这就是为什么我在这里寻求帮助。你知道吗


旁注:

1)在rt_mean_sq_dev4中的K是一个修正项,通过减少灾难性消除的影响来提高函数的精度;这一思想受到here的启发


EDIT1:忘记在我原来的函数rt_mean_sq_dev中键入np.sum()。这可能会引起任何读者对这个问题的困惑。我的道歉


Tags: 函数fordatasizenpsqdev4mean

热门问题