如何提升Python中odeint的速度?
我正在使用Python和scipy包中的odeint来解决大量(大约100万个)耦合的常微分方程(ODE)。这些方程可以通过一些矩阵乘法的和来表示,我使用numpy并且支持blas来处理这些计算。我的问题是,这个过程花费的时间非常长。当我分析代码时,我发现大部分时间都花在odeint上做其他事情,而不是在计算右侧(rhs)。以下是分析器显示的五个最耗时的调用:
ncalls tottime percall cumtime percall filename:lineno(function)
5 1547.915 309.583 1588.170 317.634 {scipy.integrate._odepack.odeint}
60597 11.535 0.000 23.751 0.000 terms3D.py:5(two_body_evolution)
121194 11.242 0.000 11.242 0.000 {numpy.core._dotblas.dot}
60597 10.145 0.000 15.460 0.000 generator.py:13(Gs2)
121203 3.615 0.000 3.615 0.000 {method 'repeat' of 'numpy.ndarray' objects}
右侧的内容基本上是two_body_evolution和Gs2。这个分析是针对大约7000个耦合的ODE,而这里是针对大约4000个的情况:
ncalls tottime percall cumtime percall filename:lineno(function)
5 259.427 51.885 273.316 54.663 {scipy.integrate._odepack.odeint}
30832 3.809 0.000 7.864 0.000 terms3D.py:5(two_body_evolution)
61664 3.650 0.000 3.650 0.000 {numpy.core._dotblas.dot}
30832 3.464 0.000 5.637 0.000 generator.py:13(Gs2)
61673 1.280 0.000 1.280 0.000 {method 'repeat' of 'numpy.ndarray' objects}
所以我主要的问题是,odeint中“隐藏”的时间随着方程数量的增加而急剧增加。你们有没有想法为什么会这样,以及如何提高性能呢?
谢谢你的时间
Oscar Åkerlund
2 个回答
一千万个方程可不是个小数字。
你为什么说它“扩展得很糟糕”?矩阵加法对于 m x m 的矩阵来说是 O(m^2)
,而乘法是 O(m^3)
。你提到的墙钟时间和方程数量/自由度之间的关系,只能描述一条直线。我建议你在4K和1000万之间选几个中间点,看看大O符号是如何表现扩展性的。把结果拟合成一个三次多项式,看看墙钟时间和自由度之间的关系,这样能告诉你扩展的情况。
你的方程是线性的还是非线性的?是静态的还是瞬态的?根据问题的类型,你可能可以调整一些其他参数,比如时间步长、收敛标准、积分方案选择等等。
这至少是导致时间消耗的一个可能原因:
如果你没有给odeint
提供雅可比矩阵(也就是LSODA),它就会通过有限差分的方法来计算这个矩阵。而且,如果它认为问题比较复杂,它可能还会尝试反转雅可比矩阵,这个过程的计算量是O(m^3),也就是说,随着变量数量的增加,计算的难度会大幅增加。
你可以通过强制odeint
使用带状雅可比矩阵来减少这些操作所需的时间,方法是给例程传入合适的ml
和mu
参数。你不需要提供Dfun,这些参数同样适用于通过微分计算得到的雅可比矩阵。