2024-06-17 11:11:36 发布
网友
我对完整的Python代码(带有数学公式)感兴趣,其中包含从头开始计算自然Cubic Splines所需的所有计算。如果可能,快速(例如基于Numpy)
我创建这个问题只是为了分享我最近在学习三次样条曲线时从零开始(基于维基百科)编写的代码(作为答案)
我基于俄语Wikipedia Article编写了以下代码,因为我看到几乎相同的描述和公式都是用英语Article编写的
为了加快计算速度,我同时使用了Numpy和Numba
为了检查代码的正确性,我通过与scipy.interpolate.CubicSpline的自然三次样条函数的参考实现进行了比较测试,您可以在我的代码中看到np.allclose(...)断言,它证明了我的公式是正确的
np.allclose(...)
此外,我还做了计时:
calc (spline_scipy): Timed best=2.712 ms, mean=2.792 +- 0.1 ms calc (spline_numba): Timed best=916.000 us, mean=938.868 +- 17.9 us speedup: 2.973 use (spline_scipy): Timed best=5.262 ms, mean=5.320 +- 0.1 ms use (spline_numba): Timed best=4.745 ms, mean=5.420 +- 0.3 ms speedup: 0.981
这表明我的样条曲线参数计算速度大约是Scipy版本的3x倍,并且样条曲线(针对给定x的计算)的使用速度与Scipy相同
3x
x
运行下面的代码需要一次性安装以下软件包python -m pip install numpy numba scipy timerit,此处scipy和timerit仅用于测试目的,实际算法不需要
python -m pip install numpy numba scipy timerit
scipy
timerit
代码绘制曲线图,显示Scipy和Numba版本的原始多线和两条样条曲线近似值,可以看到Scipy和Numba线相同(意味着样条曲线计算相同):
代码:
Try it online!
import numpy as np, numba # Solves linear system given by Tridiagonal Matrix # Helper for calculating cubic splines @numba.njit( [f'f{ii}[:](f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:])' for ii in (4, 8)], cache = True, fastmath = True, inline = 'always') def tri_diag_solve(A, B, C, F): n = B.size assert A.ndim == B.ndim == C.ndim == F.ndim == 1 and ( A.size == B.size == C.size == F.size == n ) #, (A.shape, B.shape, C.shape, F.shape) Bs, Fs = np.zeros_like(B), np.zeros_like(F) Bs[0], Fs[0] = B[0], F[0] for i in range(1, n): Bs[i] = B[i] - A[i] / Bs[i - 1] * C[i - 1] Fs[i] = F[i] - A[i] / Bs[i - 1] * Fs[i - 1] x = np.zeros_like(B) x[-1] = Fs[-1] / Bs[-1] for i in range(n - 2, -1, -1): x[i] = (Fs[i] - C[i] * x[i + 1]) / Bs[i] return x # Calculate cubic spline params @numba.njit( #[f'(f{ii}, f{ii}, f{ii}, f{ii})(f{ii}[:], f{ii}[:])' for ii in (4, 8)], cache = True, fastmath = True, inline = 'always') def calc_spline_params(x, y): a = y h = np.diff(x) c = np.concatenate((np.zeros((1,), dtype = y.dtype), np.append(tri_diag_solve(h[:-1], (h[:-1] + h[1:]) * 2, h[1:], ((a[2:] - a[1:-1]) / h[1:] - (a[1:-1] - a[:-2]) / h[:-1]) * 3), 0))) d = np.diff(c) / (3 * h) b = (a[1:] - a[:-1]) / h + (2 * c[1:] + c[:-1]) / 3 * h return a[1:], b, c[1:], d # Spline value calculating function, given params and "x" @numba.njit( [f'f{ii}[:](f{ii}[:], i8[:], f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:])' for ii in (4, 8)], cache = True, fastmath = True, inline = 'always') def func_spline(x, ix, x0, a, b, c, d): dx = x - x0[1:][ix] return a[ix] + (b[ix] + (c[ix] + d[ix] * dx) * dx) * dx @numba.njit( [f'i8[:](f{ii}[:], f{ii}[:], b1)' for ii in (4, 8)], cache = True, fastmath = True, inline = 'always') def searchsorted_merge(a, b, sort_b): ix = np.zeros((len(b),), dtype = np.int64) if sort_b: ib = np.argsort(b) pa, pb = 0, 0 while pb < len(b): if pa < len(a) and a[pa] < (b[ib[pb]] if sort_b else b[pb]): pa += 1 else: ix[pb] = pa pb += 1 return ix # Compute piece-wise spline function for "x" out of sorted "x0" points @numba.njit([f'f{ii}[:](f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:], f{ii}[:])' for ii in (4, 8)], cache = True, fastmath = True, inline = 'always') def piece_wise_spline(x, x0, a, b, c, d): xsh = x.shape x = x.ravel() #ix = np.searchsorted(x0[1 : -1], x) ix = searchsorted_merge(x0[1 : -1], x, False) y = func_spline(x, ix, x0, a, b, c, d) y = y.reshape(xsh) return y def test(): import matplotlib.pyplot as plt, scipy.interpolate from timerit import Timerit Timerit._default_asciimode = True np.random.seed(0) def f(n): x = np.sort(np.random.uniform(0., n / 5 * np.pi, (n,))).astype(np.float64) return x, (np.sin(x) * 5 + np.sin(1 + 2.5 * x) * 3 + np.sin(2 + 0.5 * x) * 2).astype(np.float64) def spline_numba(x0, y0): a, b, c, d = calc_spline_params(x0, y0) return lambda x: piece_wise_spline(x, x0, a, b, c, d) def spline_scipy(x0, y0): f = scipy.interpolate.CubicSpline(x0, y0, bc_type = 'natural') return lambda x: f(x) def timings(): x0, y0 = f(10000) s, t = {}, [] gs = [spline_scipy, spline_numba] spline_numba(np.copy(x0[::3]), np.copy(y0[::3])) # pre-compile numba for g in gs: print('calc (', g.__name__, '): ', sep = '', end = '', flush = True) tim = Timerit(num = 150, verbose = 1) for _ in tim: s_ = g(x0, y0) s[g.__name__] = s_ t.append(tim.mean()) if len(t) >= 2: print('speedup:', round(t[-2] / t[-1], 3)) print() x = np.linspace(x0[0], x0[-1], 50000, dtype = np.float64) t = [] s['spline_numba'](np.copy(x[::3])) # pre-compile numba for i in range(len(s)): print('use (', gs[i].__name__, '): ', sep = '', end = '', flush = True) tim = Timerit(num = 100, verbose = 1) sg = s[gs[i].__name__] for _ in tim: sg(x) t.append(tim.mean()) if len(t) >= 2: print('speedup:', round(t[-2] / t[-1], 3)) x0, y0 = f(50) timings() shift = 3 x = np.linspace(x0[0], x0[-1], 1000, dtype = np.float64) ys = spline_scipy(x0, y0)(x) yn = spline_numba(x0, y0)(x) assert np.allclose(ys, yn), np.absolute(ys - yn).max() plt.plot(x0, y0, label = 'orig') plt.plot(x, ys, label = 'spline_scipy') plt.plot(x, yn, '-.', label = 'spline_numba') plt.legend() plt.show() if __name__ == '__main__': test()
我基于俄语Wikipedia Article编写了以下代码,因为我看到几乎相同的描述和公式都是用英语Article编写的
为了加快计算速度,我同时使用了Numpy和Numba
为了检查代码的正确性,我通过与scipy.interpolate.CubicSpline的自然三次样条函数的参考实现进行了比较测试,您可以在我的代码中看到
np.allclose(...)
断言,它证明了我的公式是正确的此外,我还做了计时:
这表明我的样条曲线参数计算速度大约是Scipy版本的
3x
倍,并且样条曲线(针对给定x
的计算)的使用速度与Scipy相同运行下面的代码需要一次性安装以下软件包
python -m pip install numpy numba scipy timerit
,此处scipy
和timerit
仅用于测试目的,实际算法不需要代码绘制曲线图,显示Scipy和Numba版本的原始多线和两条样条曲线近似值,可以看到Scipy和Numba线相同(意味着样条曲线计算相同):
代码:
Try it online!
相关问题 更多 >
编程相关推荐