Python的有效张量收缩

2024-04-20 06:24:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一段代码,其中有一个涉及张量收缩的瓶颈计算。假设我想计算一个张量a{I,j,k,l}(X),其单个X中的非零项为N~10^5,X表示一个有M个总点的网格,大约有M~1000。对于张量a的单个元素,方程的rhs类似于:

A{ijkl}(M)=Sum{M,n,p,q}S{i,j,M,n}(M)B{M,n,p,q}(M)T{p,q,k,l}(M)

此外,中间张量B_{m,n,p,q}(m)是通过数组的数值卷积获得的,因此:

B{m,n,p,q}(m)=(L{m,n}*F{p,q})(m)

其中“*”是卷积算子,所有张量的元素数都与A近似相同。我的问题与和的效率有关;考虑到问题的复杂性,计算a的单个rhs需要很长时间。我有一个“键”系统,每个张量元素都是通过字典中唯一的键组合(例如,T的键组合(p,q,k,l))来访问的。然后,该特定键的字典给出与该键关联的Numpy数组以执行操作,所有操作(卷积、乘法…)都使用Numpy完成。我已经看到,最耗时的部分实际上是由于嵌套的循环(I在张量的所有键(I,j,k,l)上循环,并且对于每个键,需要计算类似于上面的rhs)。有什么有效的方法可以做到这一点吗?考虑一下:

1)由于所有张量都是复数类型,因此使用4+1 D的简单numpy数组会导致高内存使用率 2)我尝试过几种方法:Numba在使用字典时非常有限,我需要的一些重要Numpy功能目前不受支持。例如,numpy.convolve()只接受前2个参数,但不接受“mode”参数,这大大减少了所需的卷积间隔在这种情况下,我不需要卷积的“full”输出

3)我最近的方法是尝试在这一部分中使用Cython实现所有功能。。。但考虑到代码的逻辑,这相当耗时,而且更容易出错

有没有关于如何使用Python处理这种复杂性的想法

谢谢


Tags: 方法代码功能numpy元素参数字典数组
1条回答
网友
1楼 · 发布于 2024-04-20 06:24:39

您必须使您的问题更加精确,其中还包括一个您已经尝试过的工作代码示例。例如,我们不清楚为什么在这个张量收缩中使用字典。字典查找对于这个计算来说似乎是一件令人厌倦的事情,但也许我没有理解你真正想做的事情

张量收缩实际上很容易在Python中实现(Numpy),有一些方法可以找到收缩张量的最佳方法,而且它们非常容易使用(np.einsum)

创建一些数据(这应该是问题的一部分)

import numpy as np
import time

i=20
j=20
k=20
l=20

m=20
n=20
p=20
q=20

#I don't know what complex 2 means, I assume it is complex128 (real and imaginary part are in float64)

#size of all arrays is 1.6e5
Sum_=np.random.rand(m,n,p,q).astype(np.complex128)
S_=np.random.rand(i,j,m,n).astype(np.complex128)
B_=np.random.rand(m,n,p,q).astype(np.complex128)
T_=np.random.rand(p,q,k,l).astype(np.complex128)

天真的方式

此代码基本上与使用Cython或Numba在循环中编写代码相同,无需调用BLAS例程(ZGEMM)或优化收缩顺序->;8个嵌套循环来完成此任务

t1=time.time()
A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_)
print(time.time()-t1)

这导致运行时间非常慢,大约330秒

如何将速度提高7700倍

%timeit A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
#42.9 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

为什么这要快得多?

让我们看看收缩路径和内部结构

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
print(path[1])

    #  Complete contraction:  mnpq,ijmn,mnpq,pqkl->ijkl
#         Naive scaling:  8
#     Optimized scaling:  6
#      Naive FLOP count:  1.024e+11
#  Optimized FLOP count:  2.562e+08
#   Theoretical speedup:  399.750
#  Largest intermediate:  1.600e+05 elements
#                                     
#scaling                  current                                remaining
#                                     
#   4             mnpq,mnpq->mnpq                     ijmn,pqkl,mnpq->ijkl
#   6             mnpq,ijmn->ijpq                          pqkl,ijpq->ijkl
#   6             ijpq,pqkl->ijkl                               ijkl->ijkl

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal",einsum_call=True)
print(path[1])

#[((2, 0), set(), 'mnpq,mnpq->mnpq', ['ijmn', 'pqkl', 'mnpq'], False), ((2, 0), {'n', 'm'}, 'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0), {'p', 'q'}, 'ijpq,pqkl->ijkl', ['ijkl'], True)]

在多个精心选择的步骤中进行收缩可将所需的失败次数减少400倍。但这不是einsum在这里做的唯一事情。只要看看'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0)True代表BLAS收缩->;tensordot呼叫->;(矩阵矩阵乘法)

从内部看,这基本上如下所示:

#consider X as a 4th order tensor {mnpq}
#consider Y as a 4th order tensor {ijmn}

X_=X.reshape(m*n,p*q)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
Y_=Y.reshape(i*j,m*n)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
res=np.dot(Y_,X_)           #-> dot is just a wrapper for highly optimized BLAS functions, in case of complex128 ZGEMM
output=res.reshape(i,j,p,q) #-> just another view on the data (4D), costs almost nothing (no copy, just a view)

相关问题 更多 >