Sympy多项式的KKT点最佳寻找方法

1 投票
1 回答
49 浏览
提问于 2025-04-12 02:07

我正在尝试在一些低度多项式上运行梯度下降(GD),比如说三次多项式,变量数量是n,并且在一个被限制在超立方体[-1,1]^n的范围内。我想把GD的结束点和用Sympy计算的实际KKT点进行比较。

举个例子,我有一个包含三个变量a、b和c的Sympy表达式(也就是n=3),像这样:

0.2*a**3 - 0.1*a**2*b - 0.9*a**2*c + 0.5*a**2 - 0.1*a*b**2 - 0.4*a*b*c + 0.5*a*b - 0.6*a*c**2 - 0.5*a*c - 0.6*a - 0.1*b**3 + 0.4*b**2*c - 0.4*b**2 + 0.7*b*c**2 - 0.4*b*c + 0.9*b + 0.09*c**3 - 0.6*c**2 + 0.4*c + 0.22

那么,找到这个表达式所有KKT点的最佳、最快的方法是什么呢?并且需要满足以下约束条件:

-1 <= a <= 1
-1 <= b <= 1
-1 <= c <= 1

(理想情况下,搜索应该仅限于GD结束点附近,以节省计算时间,但我不太确定如何严格做到这一点)

1 个回答

3

先把这个表达式和它的雅可比矩阵(就是一种数学工具)预先编译成函数对象,然后把它们传递给 minimize 函数。需要注意的是,这个方法会比你想的梯度下降算法更聪明一些,但这可能对你有帮助,也可能没有。默认情况下,它使用的是Kraft的顺序最小二乘法。

import string

import numpy as np
import scipy.optimize
import sympy

expr_str = (
    '0.2*a**3 - 0.1*a**2*b - 0.9*a**2*c + 0.5*a**2 - 0.1*a*b**2 - 0.4*a*b*c '
    '+ 0.5*a*b - 0.6*a*c**2 - 0.5*a*c - 0.6*a - 0.1*b**3 + 0.4*b**2*c '
    '- 0.4*b**2 + 0.7*b*c**2 - 0.4*b*c + 0.9*b + 0.09*c**3 - 0.6*c**2 + 0.4*c + 0.22'
)
symbols = {
    c: sympy.Symbol(name=c, real=True, finite=True)
    for c in string.ascii_lowercase
    if c in expr_str
}
expr = sympy.parse_expr(s=expr_str, local_dict=symbols)
expr_callable = sympy.lambdify(
    args=[symbols.values()], expr=expr,
)

jac = [
    sympy.diff(expr, sym)
    for sym in symbols.values()
]
jac_callable = sympy.lambdify(
    args=[symbols.values()], expr=jac,
)

result = scipy.optimize.minimize(
    fun=expr_callable,
    jac=jac_callable,
    x0=np.zeros(len(symbols)),
    bounds=scipy.optimize.Bounds(lb=-1, ub=1),
)
assert result.success
print(result)
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
  message: CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
  success: True
   status: 0
      fun: -4.020346688237437
        x: [ 5.139e-01 -1.000e+00 -1.000e+00]
      nit: 5
      jac: [-1.222e-08  3.839e+00  4.398e+00]
     nfev: 7
     njev: 7
 hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>

如果你真的想要一个符号解来求 a,而 bc 都等于 -1,那么你可以反向操作,把 bc 代入第一个雅可比矩阵的部分,然后解出 a

jac_rhs = jac[0].subs({
    symbols['b']: result.x[1],
    symbols['c']: result.x[2],
})
print(jac_rhs)
soln = sympy.solve(jac_rhs, symbols['a'],
                   dict=True, numerical=False, simplify=False, rational=True)
pprint(soln)
0.6*a**2 + 3.0*a - 1.7
[{a: -5/2 + sqrt(327)/6}, {a: -sqrt(327)/6 - 5/2}]

撰写回答