仅使用NumPy einsum处理上三角元素

2024-05-15 04:44:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我用numpy-einsum来计算一个列向量pts数组的点积,它的形状是(3,N),它本身是一个矩阵dotps,shape(N,N),以及所有的点积。这是我使用的代码:

dotps = np.einsum('ij,ik->jk', pts, pts)

这是可行的,但我只需要主对角线上方的值。即结果的上三角部分没有对角线。有没有可能只用einsum计算这些值?或者以任何比使用einsum计算整个矩阵更快的方式?在

我的pts数组可以相当大,所以如果我只计算我需要的值,我的计算速度将加倍。在


Tags: 代码numpynp矩阵数组向量ikpts
1条回答
网友
1楼 · 发布于 2024-05-15 04:44:12

您可以将相关列切片,然后使用^{}-

R,C = np.triu_indices(N,1)
out = np.einsum('ij,ij->j',pts[:,R],pts[:,C])

样本运行-

^{pr2}$

进一步优化-

让我们为我们的方法计时,看看是否有改进性能的余地。在

In [126]: N = 5000

In [127]: pts = np.random.rand(3,N)

In [128]: %timeit np.triu_indices(N,1)
1 loops, best of 3: 413 ms per loop

In [129]: R,C = np.triu_indices(N,1)

In [130]: %timeit np.einsum('ij,ij->j',pts[:,R],pts[:,C])
1 loops, best of 3: 1.47 s per loop

在内存限制内,我们似乎无法对np.einsum进行优化。所以,让我们把焦点转移到np.triu_indices。在

对于N = 4,我们有:

In [131]: N = 4

In [132]: np.triu_indices(N,1)
Out[132]: (array([0, 0, 0, 1, 1, 2]), array([1, 2, 3, 2, 3, 3]))

它似乎在创造一种规律性的模式,但有点像是一种变化的模式。这可以用一个累加的和来写,它在35位置有移位。一般来说,我们最终会把它编码成这样-

def triu_indices_cumsum(N):

    # Length of R and C index arrays
    L = (N*(N-1))/2

    # Positions along the R and C arrays that indicate 
    # shifting to the next row of the full array
    shifts_idx = np.arange(2,N)[::-1].cumsum()

    # Initialize "shift" arrays for finally leading to R and C
    shifts1_arr = np.zeros(L,dtype=int)
    shifts2_arr = np.ones(L,dtype=int)

    # At shift positions along the shifts array set appropriate values, 
    # such that when cumulative summed would lead to desired R and C arrays.
    shifts1_arr[shifts_idx] = 1
    shifts2_arr[shifts_idx] = -np.arange(N-2)[::-1]

    # Finall cumsum to give R, C
    R_arr = shifts1_arr.cumsum()
    C_arr = shifts2_arr.cumsum()
    return R_arr, C_arr

让我们为不同的N's计时!在

In [133]: N = 100

In [134]: %timeit np.triu_indices(N,1)
10000 loops, best of 3: 122 µs per loop

In [135]: %timeit triu_indices_cumsum(N)
10000 loops, best of 3: 61.7 µs per loop

In [136]: N = 1000

In [137]: %timeit np.triu_indices(N,1)
100 loops, best of 3: 17 ms per loop

In [138]: %timeit triu_indices_cumsum(N)
100 loops, best of 3: 16.3 ms per loop

因此,对于体面的N's,基于triu_indices的定制cumsum可能值得一看!在

相关问题 更多 >

    热门问题