使用NumPy快速旋转张量

44 投票
7 回答
8894 浏览
提问于 2025-04-16 11:36

在一个用Python编写的应用程序中,我需要对一个四阶张量进行旋转。实际上,我需要旋转很多张量,而且次数也很多,这成了我的瓶颈。我最简单的实现方式(如下所示)用了八个嵌套循环,速度似乎很慢,但我不知道怎么利用NumPy的矩阵操作来加快速度。我觉得我应该使用np.tensordot,但我不知道该怎么做。

从数学上讲,旋转后的张量元素T'是通过以下公式计算的:T'ijkl = Σ gia gjb gkc gld Tabcd,这里的求和是针对右边重复的索引。T和T'都是3*3*3*3的NumPy数组,而旋转矩阵g是一个3*3的NumPy数组。我的慢实现(每次调用大约需要0.04秒)如下。

#!/usr/bin/env python

import numpy as np

def rotT(T, g):
    Tprime = np.zeros((3,3,3,3))
    for i in range(3):
        for j in range(3):
            for k in range(3):
                for l in range(3):
                    for ii in range(3):
                        for jj in range(3):
                            for kk in range(3):
                                for ll in range(3):
                                    gg = g[ii,i]*g[jj,j]*g[kk,k]*g[ll,l]
                                    Tprime[i,j,k,l] = Tprime[i,j,k,l] + \
                                         gg*T[ii,jj,kk,ll]
    return Tprime

if __name__ == "__main__":

    T = np.array([[[[  4.66533067e+01,  5.84985000e-02, -5.37671310e-01],
                    [  5.84985000e-02,  1.56722231e+01,  2.32831900e-02],
                    [ -5.37671310e-01,  2.32831900e-02,  1.33399259e+01]],
                   [[  4.60051700e-02,  1.54658176e+01,  2.19568200e-02],
                    [  1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
                    [  2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
                   [[ -5.35577630e-01,  1.95558600e-02,  1.31108757e+01],
                    [  1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
                    [  1.31108757e+01, -6.67615000e-03,  6.90486240e-01]]],
                  [[[  4.60051700e-02,  1.54658176e+01,  2.19568200e-02],
                    [  1.54658176e+01, -5.18223500e-02, -1.52814920e-01],
                    [  2.19568200e-02, -1.52814920e-01, -2.43874100e-02]],
                   [[  1.57414726e+01, -3.86167500e-02, -1.55971950e-01],
                    [ -3.86167500e-02,  4.65601977e+01, -3.57741000e-02],
                    [ -1.55971950e-01, -3.57741000e-02,  1.34215636e+01]],
                   [[  2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
                    [ -1.49072770e-01, -3.63410500e-02,  1.32039847e+01],
                    [ -7.38843000e-03,  1.32039847e+01,  1.38172700e-02]]],
                  [[[ -5.35577630e-01,  1.95558600e-02,  1.31108757e+01],
                    [  1.95558600e-02, -1.51342210e-01, -6.67615000e-03],
                    [  1.31108757e+01, -6.67615000e-03,  6.90486240e-01]],
                   [[  2.58256300e-02, -1.49072770e-01, -7.38843000e-03],
                    [ -1.49072770e-01, -3.63410500e-02,  1.32039847e+01],
                    [ -7.38843000e-03,  1.32039847e+01,  1.38172700e-02]],
                   [[  1.33639532e+01, -1.26331100e-02,  6.84650400e-01],
                    [ -1.26331100e-02,  1.34222177e+01,  1.67851800e-02],
                    [  6.84650400e-01,  1.67851800e-02,  4.89151396e+01]]]])

    g = np.array([[ 0.79389393,  0.54184237,  0.27593346],
                  [-0.59925749,  0.62028664,  0.50609776],
                  [ 0.10306737, -0.56714313,  0.8171449 ]])

    for i in range(100):
        Tprime = rotT(T,g)

有没有办法让这个运行得更快?如果能让代码适用于其他阶数的张量,那就更好了,不过这不是最重要的。

7 个回答

19

多亏了M. Wiebe的努力,Numpy的下一个版本(可能是1.6)将会让这件事变得更简单:

>>> Trot = np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T)

不过,Philipp的方法现在快了3倍,但也许还有改进的空间。速度差异主要是因为tensordot能够把整个操作展开成一个单一的矩阵乘法,这样就可以交给BLAS处理,从而避免了小数组带来的很多额外开销——而对于一般的爱因斯坦求和,这种方法就不行,因为并不是所有可以用这种形式表达的操作都能简化成一个单一的矩阵乘法。

34

这里是用一个Python循环来实现的方法:

def rotT(T, g):
    Tprime = T
    for i in range(4):
        slices = [None] * 4
        slices[i] = slice(None)
        slices *= 2
        Tprime = g[slices].T * Tprime
    return Tprime.sum(-1).sum(-1).sum(-1).sum(-1)

老实说,刚开始看这个可能有点难理解,但其实它运行起来要快很多 :)

42

要使用 tensordot,首先计算 g 张量的外积:

def rotT(T, g):
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)
    axes = ((0, 2, 4, 6), (0, 1, 2, 3))
    return np.tensordot(gggg, T, axes)

在我的系统上,这个方法比Sven的解决方案快大约七倍。如果 g 张量不经常变化,你还可以缓存 gggg 张量。这样做并开启一些微优化(比如内联 tensordot 的代码,不做检查,不使用通用形状),你还可以让速度再快两倍:

def rotT(T, gggg):
    return np.dot(gggg.transpose((1, 3, 5, 7, 0, 2, 4, 6)).reshape((81, 81)),
                  T.reshape(81, 1)).reshape((3, 3, 3, 3))

这是我在家用笔记本上用 timeit 测试的结果(进行了500次迭代):

Your original code: 19.471129179
Sven's code: 0.718412876129
My first code: 0.118047952652
My second code: 0.0690279006958

我在工作电脑上的数据是:

Your original code: 9.77922987938
Sven's code: 0.137110948563
My first code: 0.0569641590118
My second code: 0.0308079719543

撰写回答