如何创建任意长度的numpy.piecewise函数?(遇到lambda问题)

0 投票
1 回答
1137 浏览
提问于 2025-04-18 08:05

我正在尝试为我的数据绘制一个分段拟合的图,但我需要能够处理任意数量的线段。有时候需要三段,有时候只需要两段。我把拟合的系数存储在一个叫做actable的地方,把线段的边界存储在btable里。

以下是我边界的示例值:

btable = [[0.00499999989, 0.0244274978], [0.0244275965, 0.0599999987]]

以下是我系数的示例值:

actable = [[0.0108687987, -0.673182865, 14.6420775], [0.00410866373, -0.0588355861, 1.07750032]]

这是我的代码:

rfig = plt.figure()
<>various other plot specifications<>
x = np.arange(0.005, 0.06, 0.0001)
y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x=x: np.log10(actable[j][0] + actable[j][2] * x + actable[j][2] * x**2) for j in list(range(len(actable)))])
plt.plot(x, y)

问题是,lambda会自动设置为列表中的最后一个实例,所以它把最后一段的系数用在了所有的段上。我不知道如何在不使用lambda的情况下做分段函数。

目前,我在用一种不太好的方法来解决这个问题:

if len(btable) == 2:
    y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x: np.log10(actable[0][0] + actable[0][1] * x + actable[0][2] * x**2), lambda x: np.log10(actable[1][0] + actable[1][1] * x + actable[1][2] * x**2)])
else if len(btable) == 3:
    y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x: np.log10(actable[0][0] + actable[0][1] * x + actable[0][2] * x**2), lambda x: np.log10(actable[1][0] + actable[1][1] * x + actable[1][2] * x**2), lambda x: np.log10(actable[2][0] + actable[2][1] * x + actable[2][2] * x**2)])
else
    print('Oh no!  You have fewer than 2 or more than 3 segments!')

但这样让我感觉不太舒服。我知道一定有更好的解决办法。有人能帮忙吗?

1 个回答

0

这个问题挺常见的,以至于Python的官方文档里有一篇文章,标题是“为什么在循环中定义的不同值的lambda都返回相同的结果?”,里面给出了一种解决办法:在循环中创建一个局部变量,用来保存循环变量的值,这样就能在函数里捕捉到这些变化的值。

也就是说,在定义y的时候,只需要把

[lambda x=x: np.log10(actable[j][0] + actable[j][1] * x + actable[j][2] * x**2) for j in range(len(actable))]

替换成

[lambda x=x, k=j: np.log10(actable[k][0] + actable[k][1] * x + actable[k][2] * x**2) for j in range(len(actable))]

顺便提一下,可以用单边不等式来指定numpy.piecewise的范围:最后一个返回True的条件会触发相应的函数。(这个优先级有点反直觉;用第一个返回True的条件会更自然,就像SymPy那样)。如果断点是按升序排列的,那么应该使用“x>=”的不等式:

breaks = np.arange(0, 10)       # breakpoints
coeff = np.arange(0, 20, 2)     # coefficients to use
x = np.arange(0, 10, 0.1)
y = np.piecewise(x, [x >= b for b in breaks], [lambda x=x, a=c: a*x for c in coeff])

在这里,每个系数会用于以对应的断点“开始”的区间;比如,系数c=0用于范围0<=x<1,系数c=2用于范围1<=x<2,依此类推。

撰写回答