在sympy中分解多项式

10 投票
3 回答
7713 浏览
提问于 2025-04-17 17:43

我正在做一个非常简单的概率计算,想从一组字母A-Z中选出X、Y、Z这几个字母,并且每个字母都有对应的概率x、y、z。

因为公式比较复杂,为了更好地处理这些公式,我想用sympy简化(或者说合并因式分解,我不太确定具体的定义)这些多项式表达式。

所以,我有这样一个(从A-Z中选出X、Y、Z的简单概率计算表达式,带有对应的概率x、y、z)

import sympy as sp

x, y, z = sp.symbols('x y z')

expression = (
    x * (1 - x) * y * (1 - x - y) * z +
    x * (1 - x) * z * (1 - x - z) * y +

    y * (1 - y) * x * (1 - y - x) * z +
    y * (1 - y) * z * (1 - y - z) * x +

    z * (1 - z) * y * (1 - z - y) * x +
    z * (1 - z) * x * (1 - z - x) * y
)

我想得到类似这样的结果

x * y * z * (6 * (1 - x - y - z) + (x + y) ** 2 + (y + z) ** 2 + (x + z) ** 2)

一个多项式,重写成尽量少的运算(+-***等)


我尝试了factor()collect()simplify(),但是结果和我预期的不太一样。大多数情况下我得到的是

2*x*y*z*(x**2 + x*y + x*z - 3*x + y**2 + y*z - 3*y + z**2 - 3*z + 3)

我知道sympy可以将多项式合并成简单的形式:

sp.factor(x**2 + 2*x*y + y**2)  # gives (x + y)**2

但我该如何让sympy从上面的表达式中合并这些多项式呢?


如果在sympy中这是一项不可能的任务,是否还有其他选择呢?

3 个回答

2

我之前也遇到过类似的问题,最后自己实现了一个解决方案,后来才发现了这个方法。我的方法似乎在减少操作次数方面做得更好。不过,我的方法也采用了暴力破解的方式,遍历所有变量的组合。因此,随着变量数量的增加,它的运行时间会呈现超指数级增长。另一方面,我成功在一个有7个变量的方程上运行了它,虽然时间不算太长,但也远没有达到实时处理的水平。

可能有一些方法可以优化搜索过程,减少一些不必要的分支,但我没有去研究这些。欢迎大家提出进一步的优化建议。

def collect_best(expr, measure=sympy.count_ops):
    # This method performs sympy.collect over all permutations of the free variables, and returns the best collection
    best = expr
    best_score = measure(expr)
    perms = itertools.permutations(expr.free_symbols)
    permlen = np.math.factorial(len(expr.free_symbols))
    print(permlen)
    for i, perm in enumerate(perms):
        if (permlen > 1000) and not (i%int(permlen/100)):
            print(i)
        collected = sympy.collect(expr, perm)
        if measure(collected) < best_score:
            best_score = measure(collected)
            best = collected
    return best

def product(args):
    arg = next(args)
    try:
        return arg*product(args)
    except:
        return arg

def rcollect_best(expr, measure=sympy.count_ops):
    # This method performs collect_best recursively on the collected terms
    best = collect_best(expr, measure)
    best_score = measure(best)
    if expr == best:
        return best
    if isinstance(best, sympy.Mul):
        return product(map(rcollect_best, best.args))
    if isinstance(best, sympy.Add):
        return sum(map(rcollect_best, best.args))

为了说明性能,这篇论文(需要付费,抱歉)中有7个公式,这些公式是7个变量的5次多项式,最多有29项和158个操作。在应用了rcollect_best和@smichr的iflfactor之后,这7个公式中的操作次数分别是:

[6, 15, 100, 68, 39, 13, 2]

以及

[32, 37, 113, 73, 40, 15, 2]

。其中,iflfactor的操作次数比rcollect_best多了433%对于其中一个公式。此外,扩展后的公式中的操作次数是:

[39, 49, 158, 136, 79, 27, 2]
3

据我所知,没有一个函数能完全做到这一点。我觉得这其实是个很难的问题。你可以看看这个链接:减少简单表达式的操作次数,里面有一些讨论。

不过,SymPy里有很多简化函数可以尝试。其中一个你没提到的函数是 gcd_terms,它可以提取出一个符号的最大公约数,而不需要展开表达式。它的结果是

>>> gcd_terms(expression)
x*y*z*((-x + 1)*(-x - y + 1) + (-x + 1)*(-x - z + 1) + (-y + 1)*(-x - y + 1) + (-y + 1)*(-y - z + 1) + (-z + 1)*(-x - z + 1) + (-z + 1)*(-y - z + 1))

另一个有用的函数是 .count_ops,它可以计算一个表达式中的操作次数。例如

>>> expression.count_ops()
47
>>> factor(expression).count_ops()
22
>>> e = x * y * z * (6 * (1 - x - y - z) + (x + y) ** 2 + (y + z) ** 2 + (x + z) ** 2)
>>> e.count_ops()
18

(注意,e.count_ops() 计算的结果和你自己算的不一样,因为SymPy会自动把 6*(1 - x - y - z) 展开成 6 - 6*x - 6*y - 6*z)。

还有其他一些有用的函数:

  • cse:对表达式进行公共子表达式消除。有时候你可以先简化每个部分,然后再组合起来。这也有助于避免重复计算。

  • horner:对多项式应用 霍纳法则。如果多项式只有一个变量,这样可以减少操作次数。

  • factor_terms:和 gcd_terms 类似。我其实不太清楚它们之间的具体区别。

需要注意的是,默认情况下,simplify 会尝试多种简化方式,并返回通过 count_ops 最小化的结果。

7

这次把几种方法结合在一起,得到了一个不错的结果。很有意思的是,看看这种策略在你生成的方程中是否经常有效,还是说,正如名字所暗示的,这只是这次的一个幸运结果。

def iflfactor(eq):
    """Return the "I'm feeling lucky" factored form of eq."""
    e = Mul(*[horner(e) if e.is_Add else e for e in
        Mul.make_args(factor_terms(expand(eq)))])
    r, e = cse(e)
    s = [ri[0] for ri in r]
    e = Mul(*[collect(ei.expand(), s) if ei.is_Add else ei for ei in
        Mul.make_args(e[0])]).subs(r)
    return e

>>> iflfactor(eq)  # using your equation as eq
2*x*y*z*(x**2 + x*y + y**2 + (z - 3)*(x + y + z) + 3)
>>> _.count_ops()
15

顺便提一下,factor_terms和gcd_terms之间的区别在于,factor_terms会更努力地提取出共同的项,同时保持表达式的原始结构,就像你手动操作时那样(也就是说,寻找可以提取的加法中的共同项)。

>>> factor_terms(x/(z+z*y)+x/z)
x*(1 + 1/(y + 1))/z
>>> gcd_terms(x/(z+z*y)+x/z)
x*(y*z + 2*z)/(z*(y*z + z))

说实话,

Chris

撰写回答