Python - matplotlib:寻找折线图的交点

14 投票
3 回答
10569 浏览
提问于 2025-04-17 06:07

我有一个可能很简单的问题,这个问题让我困扰了很久。有没有什么简单的方法可以在Python的matplotlib中返回两个绘制的数据集的交点?

为了更详细地说明,我有这样的数据:

x=[1.4,2.1,3,5.9,8,9,23]
y=[2.3,3.1,1,3.9,8,9,11]
x1=[1,2,3,4,6,8,9]
y1=[4,12,7,1,6.3,8.5,12]
plot(x1,y1,'k-',x,y,'b-')

这个例子中的数据完全是随意的。现在我想知道是否有一个简单的内置函数,我一直没注意到,它可以给我返回这两个图形的精确交点。

希望我说得够清楚,也希望我没有遗漏什么显而易见的东西……

3 个回答

0

可以使用这个函数

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

x1 = np.array([1.4,2.1,3,5.9,8,9,23], dtype=float)
y1 = np.array([2.3,3.1,1,3.9,8,9,11], dtype=float)
x2 = np.array([1,2,3,4,6,8,9], dtype=float)
y2 = np.array([4,12,7,1,6.3,8.5,12], dtype=float)

# plt.plot(x1,y1,'ko-',x2,y2,'bo-')
# plt.show()

##https://stackoverflow.com/questions/55225385/intersection-point-between-2d-arrays
def intersection(X1, X2):
    x = np.union1d(X1[0], X2[0])
    y1 = np.interp(x, X1[0], X1[1])
    y2 = np.interp(x, X2[0], X2[1])
    dy = y1 - y2

    ind = (dy[:-1] * dy[1:] < 0).nonzero()[0]
    x1, x2 = x[ind], x[ind+1]
    dy1, dy2 = dy[ind], dy[ind+1]
    y11, y12 = y1[ind], y1[ind+1]
    y21, y22 = y2[ind], y2[ind+1]

    x_int = x1 - (x2 - x1) * dy1 / (dy2 - dy1)
    y_int = y11 + (y12 - y11) * (x_int - x1) / (x2 - x1)
    return x_int, y_int

res = intersection([x1, y1], [x2, y2])
print(res)
if len(res)!=0:
    t=np.array(res); print(t)
    d= pd.DataFrame(t.T, columns= ['x', 'y']); print(d)

plt.plot(x1,y1,'ko-',x2,y2,'bo-', d.x, d.y, 'ro')
plt.show()

在这里输入图片描述

1

参数化解决方案

如果我们有两个序列 {x1,y1} 和 {x2,y2},它们定义了一些任意的 (x,y) 曲线,而不是 y(x) 曲线,那么我们就需要用参数化的方法来寻找它们的交点。因为这个过程不是特别明显,而且 @unutbu 的解决方案使用了一个已经不再支持的 SciPy 插值器,所以我觉得重新讨论这个问题可能会有帮助。

import numpy as np
from numpy.linalg import norm
from scipy.optimize import fsolve
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

x1_array = np.array([1,2,3,4,6,8,9])
y1_array = np.array([4,12,7,1,6.3,8.5,12])
x2_array = np.array([1.4,2.1,3,5.9,8,9,23])
y2_array = np.array([2.3,3.1,1,3.9,8,9,11])

s1_array = np.linspace(0,1,num=len(x1_array))
s2_array = np.linspace(0,1,num=len(x2_array))

# Arguments given to interp1d:
#  - extrapolate: to make sure we don't get a fatal value error when fsolve searches
#                 beyond the bounds of [0,1]
#  - copy: use refs to the arrays
#  - assume_sorted: because s_array ('x') increases monotonically across [0,1]
kwargs_ = dict(fill_value='extrapolate', copy=False, assume_sorted=True)
x1_interp = interp1d(s1_array,x1_array, **kwargs_)
y1_interp = interp1d(s1_array,y1_array, **kwargs_)
x2_interp = interp1d(s2_array,x2_array, **kwargs_)
y2_interp = interp1d(s2_array,y2_array, **kwargs_)
xydiff_lambda = lambda s12: (np.abs(x1_interp(s12[0])-x2_interp(s12[1])),
                             np.abs(y1_interp(s12[0])-y2_interp(s12[1])))

s12_intercept, _, ier, mesg \
    = fsolve(xydiff_lambda, [0.5, 0.3], full_output=True) 

xy1_intercept = x1_interp(s12_intercept[0]),y1_interp(s12_intercept[0])
xy2_intercept = x2_interp(s12_intercept[1]),y2_interp(s12_intercept[1])

plt.plot(x1_interp(s1_array),y1_interp(s1_array),'b.', ls='-', label='x1 data')
plt.plot(x2_interp(s2_array),y2_interp(s2_array),'r.', ls='-', label='x2 data')
if s12_intercept[0]>0 and s12_intercept[0]<1:
    plt.plot(*xy1_intercept,'bo', ms=12, label='x1 intercept')
    plt.plot(*xy2_intercept,'ro', ms=8, label='x2 intercept')
plt.legend()

print('intercept @ s1={}, s2={}\n'.format(s12_intercept[0],s12_intercept[1]), 
      'intercept @ xy1={}\n'.format(np.array(xy1_intercept)), 
      'intercept @ xy2={}\n'.format(np.array(xy2_intercept)), 
      'fsolve apparent success? {}: "{}"\n'.format(ier==1,mesg,), 
      'is intercept really good? {}\n'.format(s12_intercept[0]>=0 and s12_intercept[0]<=1 
      and s12_intercept[1]>=0 and s12_intercept[1]<=1 
      and np.isclose(0,norm(xydiff_lambda(s12_intercept)))) )

对于这个特定的初始猜测 [0.5,0.3],这个方法会返回:

intercept @ s1=0.4761904761904762, s2=0.3825944170771757
intercept @ xy1=[3.85714286 1.85714286]
intercept @ xy2=[3.85714286 1.85714286]
fsolve apparent success? True: "The solution converged."
is intercept really good? True

这个方法只找到一个交点:我们需要尝试多个初始猜测(就像 @unutbu 的代码那样),检查它们的准确性,并使用 np.close 来消除重复的结果。需要注意的是,fsolve 可能会错误地在返回值 ier 中显示成功找到交点,这就是为什么这里需要额外的检查。

这是这个解决方案的图示: 示例解决方案

25

我们可以使用 scipy.interpolate.PiecewisePolynomial 来创建一些函数,这些函数是由你的分段线性数据定义的。

p1=interpolate.PiecewisePolynomial(x1,y1[:,np.newaxis])
p2=interpolate.PiecewisePolynomial(x2,y2[:,np.newaxis])

然后我们可以计算这两个函数之间的差值,

def pdiff(x):
    return p1(x)-p2(x)

并使用 optimize.fsolve 来找到 pdiff 的根:

import scipy.interpolate as interpolate
import scipy.optimize as optimize
import numpy as np

x1=np.array([1.4,2.1,3,5.9,8,9,23])
y1=np.array([2.3,3.1,1,3.9,8,9,11])
x2=np.array([1,2,3,4,6,8,9])
y2=np.array([4,12,7,1,6.3,8.5,12])    

p1=interpolate.PiecewisePolynomial(x1,y1[:,np.newaxis])
p2=interpolate.PiecewisePolynomial(x2,y2[:,np.newaxis])

def pdiff(x):
    return p1(x)-p2(x)

xs=np.r_[x1,x2]
xs.sort()
x_min=xs.min()
x_max=xs.max()
x_mid=xs[:-1]+np.diff(xs)/2
roots=set()
for val in x_mid:
    root,infodict,ier,mesg = optimize.fsolve(pdiff,val,full_output=True)
    # ier==1 indicates a root has been found
    if ier==1 and x_min<root<x_max:
        roots.add(root[0])
roots=list(roots)        
print(np.column_stack((roots,p1(roots),p2(roots))))

结果是

[[ 3.85714286  1.85714286  1.85714286]
 [ 4.60606061  2.60606061  2.60606061]]

第一列是 x 值,第二列是第一个 PiecewisePolynomial 在 x 处的 y 值,第三列是第二个 PiecewisePolynomial 的 y 值。

撰写回答