如何提升Python中odeint的速度?

5 投票
2 回答
3402 浏览
提问于 2025-04-16 19:18

我正在使用Python和scipy包中的odeint来解决大量(大约100万个)耦合的常微分方程(ODE)。这些方程可以通过一些矩阵乘法的和来表示,我使用numpy并且支持blas来处理这些计算。我的问题是,这个过程花费的时间非常长。当我分析代码时,我发现大部分时间都花在odeint上做其他事情,而不是在计算右侧(rhs)。以下是分析器显示的五个最耗时的调用:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
5       1547.915  309.583 1588.170  317.634 {scipy.integrate._odepack.odeint}
60597   11.535    0.000   23.751    0.000   terms3D.py:5(two_body_evolution)
121194  11.242    0.000   11.242    0.000   {numpy.core._dotblas.dot}
60597   10.145    0.000   15.460    0.000   generator.py:13(Gs2)
121203   3.615    0.000   3.615     0.000   {method 'repeat' of 'numpy.ndarray' objects}

右侧的内容基本上是two_body_evolution和Gs2。这个分析是针对大约7000个耦合的ODE,而这里是针对大约4000个的情况:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
5        259.427  51.885   273.316  54.663 {scipy.integrate._odepack.odeint}
30832    3.809    0.000    7.864    0.000  terms3D.py:5(two_body_evolution)
61664    3.650    0.000    3.650    0.000  {numpy.core._dotblas.dot}
30832    3.464    0.000    5.637    0.000  generator.py:13(Gs2)
61673    1.280    0.000    1.280    0.000  {method 'repeat' of 'numpy.ndarray' objects}

所以我主要的问题是,odeint中“隐藏”的时间随着方程数量的增加而急剧增加。你们有没有想法为什么会这样,以及如何提高性能呢?

谢谢你的时间

Oscar Åkerlund

2 个回答

2

一千万个方程可不是个小数字。

你为什么说它“扩展得很糟糕”?矩阵加法对于 m x m 的矩阵来说是 O(m^2),而乘法是 O(m^3)。你提到的墙钟时间和方程数量/自由度之间的关系,只能描述一条直线。我建议你在4K和1000万之间选几个中间点,看看大O符号是如何表现扩展性的。把结果拟合成一个三次多项式,看看墙钟时间和自由度之间的关系,这样能告诉你扩展的情况。

你的方程是线性的还是非线性的?是静态的还是瞬态的?根据问题的类型,你可能可以调整一些其他参数,比如时间步长、收敛标准、积分方案选择等等。

5

这至少是导致时间消耗的一个可能原因:

如果你没有给odeint提供雅可比矩阵(也就是LSODA),它就会通过有限差分的方法来计算这个矩阵。而且,如果它认为问题比较复杂,它可能还会尝试反转雅可比矩阵,这个过程的计算量是O(m^3),也就是说,随着变量数量的增加,计算的难度会大幅增加。

你可以通过强制odeint使用带状雅可比矩阵来减少这些操作所需的时间,方法是给例程传入合适的mlmu参数。你不需要提供Dfun,这些参数同样适用于通过微分计算得到的雅可比矩阵。

撰写回答