Sympy多项式的KKT点最佳寻找方法
我正在尝试在一些低度多项式上运行梯度下降(GD),比如说三次多项式,变量数量是n,并且在一个被限制在超立方体[-1,1]^n的范围内。我想把GD的结束点和用Sympy计算的实际KKT点进行比较。
举个例子,我有一个包含三个变量a、b和c的Sympy表达式(也就是n=3),像这样:
0.2*a**3 - 0.1*a**2*b - 0.9*a**2*c + 0.5*a**2 - 0.1*a*b**2 - 0.4*a*b*c + 0.5*a*b - 0.6*a*c**2 - 0.5*a*c - 0.6*a - 0.1*b**3 + 0.4*b**2*c - 0.4*b**2 + 0.7*b*c**2 - 0.4*b*c + 0.9*b + 0.09*c**3 - 0.6*c**2 + 0.4*c + 0.22
那么,找到这个表达式所有KKT点的最佳、最快的方法是什么呢?并且需要满足以下约束条件:
-1 <= a <= 1
-1 <= b <= 1
-1 <= c <= 1
(理想情况下,搜索应该仅限于GD结束点附近,以节省计算时间,但我不太确定如何严格做到这一点)
1 个回答
3
先把这个表达式和它的雅可比矩阵(就是一种数学工具)预先编译成函数对象,然后把它们传递给 minimize
函数。需要注意的是,这个方法会比你想的梯度下降算法更聪明一些,但这可能对你有帮助,也可能没有。默认情况下,它使用的是Kraft的顺序最小二乘法。
import string
import numpy as np
import scipy.optimize
import sympy
expr_str = (
'0.2*a**3 - 0.1*a**2*b - 0.9*a**2*c + 0.5*a**2 - 0.1*a*b**2 - 0.4*a*b*c '
'+ 0.5*a*b - 0.6*a*c**2 - 0.5*a*c - 0.6*a - 0.1*b**3 + 0.4*b**2*c '
'- 0.4*b**2 + 0.7*b*c**2 - 0.4*b*c + 0.9*b + 0.09*c**3 - 0.6*c**2 + 0.4*c + 0.22'
)
symbols = {
c: sympy.Symbol(name=c, real=True, finite=True)
for c in string.ascii_lowercase
if c in expr_str
}
expr = sympy.parse_expr(s=expr_str, local_dict=symbols)
expr_callable = sympy.lambdify(
args=[symbols.values()], expr=expr,
)
jac = [
sympy.diff(expr, sym)
for sym in symbols.values()
]
jac_callable = sympy.lambdify(
args=[symbols.values()], expr=jac,
)
result = scipy.optimize.minimize(
fun=expr_callable,
jac=jac_callable,
x0=np.zeros(len(symbols)),
bounds=scipy.optimize.Bounds(lb=-1, ub=1),
)
assert result.success
print(result)
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
message: CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
success: True
status: 0
fun: -4.020346688237437
x: [ 5.139e-01 -1.000e+00 -1.000e+00]
nit: 5
jac: [-1.222e-08 3.839e+00 4.398e+00]
nfev: 7
njev: 7
hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>
如果你真的想要一个符号解来求 a
,而 b
和 c
都等于 -1,那么你可以反向操作,把 b
和 c
代入第一个雅可比矩阵的部分,然后解出 a
。
jac_rhs = jac[0].subs({
symbols['b']: result.x[1],
symbols['c']: result.x[2],
})
print(jac_rhs)
soln = sympy.solve(jac_rhs, symbols['a'],
dict=True, numerical=False, simplify=False, rational=True)
pprint(soln)
0.6*a**2 + 3.0*a - 1.7
[{a: -5/2 + sqrt(327)/6}, {a: -sqrt(327)/6 - 5/2}]