如何让numpy.einsum与sympy配合使用?

13 投票
3 回答
1910 浏览
提问于 2025-04-17 20:11

好的,我有几个包含sympy对象(表达式)的多维numpy数组。例如:

A = array([[1.0*cos(z0)**2 + 1.0, 1.0*cos(z0)],
          [1.0*cos(z0), 1.00000000000000]], dtype=object)

还有其他类似的数组。

我想用einsum来乘这些数组,因为我之前做数值计算时已经掌握了这个语法。问题是,当我尝试做类似下面的操作时:

einsum('ik,jkim,j', A, B, C)

我遇到了类型错误:

TypeError: invalid data type for einsum

当然,快速在谷歌上搜索一下,我发现einsum可能无法处理这种情况,但没有解释为什么。特别是,当我对这些数组使用numpy.dot()和numpy.tensordot()函数时,效果很好。我可以使用tensordot来完成我的需求,但一想到要把大约五十个像上面那样的爱因斯坦求和替换成嵌套的tensordot调用,我的脑袋就疼。更可怕的是,要调试那段代码,找出一个错误的索引交换。

长话短说,有人知道为什么tensordot可以处理这些对象,而einsum却不行吗?有没有什么解决办法?如果没有,能否给我一些建议,如何编写一个类似于einsum语法的嵌套tensordot调用的包装器(用数字代替字母也可以)?

3 个回答

2

这里有一个更简单的实现方法,它把 einsum 拆分成多个 tensordot

def einsum(string, *args):
    index_groups = map(list, string.split(','))
    assert len(index_groups) == len(args)
    tensor_indices_tuples = zip(index_groups, args)
    return reduce(einsum_for_two, tensor_indices_tuples)[1]

def einsum_for_two(tensor_indices1, tensor_indices2):
    string1, tensor1 = tensor_indices1
    string2, tensor2 = tensor_indices2
    sum_over_indices = set(string1).intersection(set(string2))
    new_string = string1 + string2
    axes = ([], [])
    for i in sum_over_indices:
        new_string.remove(i)
        new_string.remove(i)
        axes[0].append(string1.index(i))
        axes[1].append(string2.index(i))
    return new_string, np.tensordot(tensor1, tensor2, axes)

首先,它把 einsum 的参数分成了 (索引, 张量) 的元组。然后,它按照以下步骤处理这个列表:

  • 取前两个元组,使用一个简单的 einsum_for_two 对它们进行计算。同时,它会打印出新的索引签名。
  • 计算得到的 einsum_for_two 的值会和列表中的下一个元组一起,作为新的参数继续进行 einsum_for_two 的计算。
  • 这个过程会一直进行,直到只剩下一个元组为止。最后,索引签名会被丢弃,只返回张量。

这个方法可能比较慢(不过你反正是使用 object dtype)。它对输入的正确性检查也不多。

正如 @seberg 提到的,我的代码不适用于张量的迹运算。

6

有趣的是,添加 optimize="optimal" 这个选项对我来说有效。

当我用 einsum('ik,jkim,j', A, B, C) 时会出现错误,但

einsum('ik,jkim,j', A, B, C, optimize="optimal") 就能在 sympy 中完美运行。

4

Einsum基本上可以替代tensordot(注意不是dot,因为dot通常使用的是优化过的线性代数库),在代码上它们完全不同。

这里有一个叫做einsum的对象,虽然没有经过复杂情况的测试,但我觉得应该能用……在C语言中做同样的事情可能会更简单,因为你可以直接借用除了循环以外的所有部分,直接用真实的einsum函数。所以如果你有兴趣,可以试着实现一下,让更多人受益……

https://gist.github.com/seberg/5236560

我不能保证任何事情,尤其是对于一些奇怪的边缘情况。当然,我相信你也可以把einsum的表示法转换成tensordot的表示法,这样可能会更快,因为循环大部分会在C语言中执行……

撰写回答