如何让numpy.einsum与sympy配合使用?
好的,我有几个包含sympy对象(表达式)的多维numpy数组。例如:
A = array([[1.0*cos(z0)**2 + 1.0, 1.0*cos(z0)],
[1.0*cos(z0), 1.00000000000000]], dtype=object)
还有其他类似的数组。
我想用einsum来乘这些数组,因为我之前做数值计算时已经掌握了这个语法。问题是,当我尝试做类似下面的操作时:
einsum('ik,jkim,j', A, B, C)
我遇到了类型错误:
TypeError: invalid data type for einsum
当然,快速在谷歌上搜索一下,我发现einsum可能无法处理这种情况,但没有解释为什么。特别是,当我对这些数组使用numpy.dot()和numpy.tensordot()函数时,效果很好。我可以使用tensordot来完成我的需求,但一想到要把大约五十个像上面那样的爱因斯坦求和替换成嵌套的tensordot调用,我的脑袋就疼。更可怕的是,要调试那段代码,找出一个错误的索引交换。
长话短说,有人知道为什么tensordot可以处理这些对象,而einsum却不行吗?有没有什么解决办法?如果没有,能否给我一些建议,如何编写一个类似于einsum语法的嵌套tensordot调用的包装器(用数字代替字母也可以)?
3 个回答
这里有一个更简单的实现方法,它把 einsum
拆分成多个 tensordot
。
def einsum(string, *args):
index_groups = map(list, string.split(','))
assert len(index_groups) == len(args)
tensor_indices_tuples = zip(index_groups, args)
return reduce(einsum_for_two, tensor_indices_tuples)[1]
def einsum_for_two(tensor_indices1, tensor_indices2):
string1, tensor1 = tensor_indices1
string2, tensor2 = tensor_indices2
sum_over_indices = set(string1).intersection(set(string2))
new_string = string1 + string2
axes = ([], [])
for i in sum_over_indices:
new_string.remove(i)
new_string.remove(i)
axes[0].append(string1.index(i))
axes[1].append(string2.index(i))
return new_string, np.tensordot(tensor1, tensor2, axes)
首先,它把 einsum
的参数分成了 (索引, 张量) 的元组。然后,它按照以下步骤处理这个列表:
- 取前两个元组,使用一个简单的
einsum_for_two
对它们进行计算。同时,它会打印出新的索引签名。 - 计算得到的
einsum_for_two
的值会和列表中的下一个元组一起,作为新的参数继续进行einsum_for_two
的计算。 - 这个过程会一直进行,直到只剩下一个元组为止。最后,索引签名会被丢弃,只返回张量。
这个方法可能比较慢(不过你反正是使用 object
dtype
)。它对输入的正确性检查也不多。
正如 @seberg 提到的,我的代码不适用于张量的迹运算。
有趣的是,添加 optimize="optimal"
这个选项对我来说有效。
当我用 einsum('ik,jkim,j', A, B, C)
时会出现错误,但
用 einsum('ik,jkim,j', A, B, C, optimize="optimal")
就能在 sympy 中完美运行。
Einsum基本上可以替代tensordot(注意不是dot,因为dot通常使用的是优化过的线性代数库),在代码上它们完全不同。
这里有一个叫做einsum的对象,虽然没有经过复杂情况的测试,但我觉得应该能用……在C语言中做同样的事情可能会更简单,因为你可以直接借用除了循环以外的所有部分,直接用真实的einsum函数。所以如果你有兴趣,可以试着实现一下,让更多人受益……
https://gist.github.com/seberg/5236560
我不能保证任何事情,尤其是对于一些奇怪的边缘情况。当然,我相信你也可以把einsum的表示法转换成tensordot的表示法,这样可能会更快,因为循环大部分会在C语言中执行……