在Python中测试数学表达式的等价性
我在Python里有两个字符串,
A m * B s / (A m + C m)
和
C m * B s / (C m + A m)
这两个字符串分别代表无序集合(A, C)和无序集合(B)。这里的m和s表示可以在同一类单位之间互换,但不能和其他单位互换。
到目前为止,我一直在对A、B和C进行排列组合,并使用eval和SymPy的==运算符来测试它们的相等性。这种方法有几个缺点:
- 对于更复杂的表达式,我需要生成大量的排列组合(在我的例子中是8个嵌套的for循环)
- 我需要把A、B、C定义为符号,但当我不知道会有什么参数时,这样做并不理想(所以我必须生成所有可能的参数,这样效率极低,还会搞乱我的变量命名空间)
有没有什么更Pythonic的方法来测试这种相等性?它应该能处理任意表达式。
5 个回答
如果你把一个字符串传给SymPy的sympify()
函数,它会自动为你创建符号(你不需要自己一个个定义它们)。
>>> from sympy import *
>>> x
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'x' is not defined
>>> sympify("x**2 + cos(x)")
x**2 + cos(x)
>>> sympify("diff(x**2 + cos(x), x)")
2*x - sin(x)
与其遍历所有可能的排列,不如假设已经存在一个排列,然后尝试去构建它。我认为,如果方法得当,算法失败就意味着这个排列根本不存在。
下面是这个想法在上述表达式中的大致思路:
设定:
str1 = "A m * B s / (A m + C m)"
str2 = "C m * B s / (C m + A m)"
我们在寻找一个集合(A, C)的排列,使得这些表达式看起来一样。根据它们在str2中第一次出现的顺序,把A和C重新标记为X1和X2,所以:
X1 = C
X2 = A
因为C在str2中出现在A之前。接下来,创建一个数组Y,使得y[i]是str1中A或C第一次出现的第i个符号。所以:
Y[1] = A
Y[2] = C
因为A在str1中出现在C之前。
现在从str2构建str3,把A和C替换为X1和X2:
str3 = "X1 m * B s / (X1 m + X2 m)"
然后开始用Y[i]替换Xi。首先,X1变成Y[1]=A:
str3_1 = "A m * Bs / (A m + X2 m)"
在这个阶段,比较str3_1和str1,直到遇到任何Xi的第一次出现,这里是X2,因为这两个字符串相等:
str3_1[:18] = "A m * B s / (A m + "
str1[:18] = "A m * B s / (A m + "
你有机会构建这个排列。如果它们不相等,你就证明了没有合适的排列存在(因为任何排列至少需要做这个替换),可以得出它们不相等。但它们是相等的,所以你继续下一步,把X2替换为Y[2]=C:
str3_2 = "A m * B s / (A m + C m)"
这和str1相等,所以你得到了你的排列(A->C, C->A),并且证明了这些表达式是等价的。
这只是对特定情况的算法演示,但应该可以推广。不确定你能把它简化到什么程度,但应该比对n个变量生成所有排列的n!要快。
如果我理解单位的意义是正确的,它们限制了哪些变量可以通过排列互换。所以在上述表达式中,A可以替换为C,因为它们都有'm'单位,但不能替换为B,因为B有's'单位。你可以这样处理:
从str1和str2构建表达式str1_m和str2_m,去掉所有没有m单位的符号,然后对str1_m和str2_m执行上述算法。如果构建失败,说明没有排列存在。如果构建成功,保留这个排列(称为m-排列),然后从str1和str2构建str1_s和str2_s,去掉所有没有s单位的符号,再次对str1_s和str2_s执行算法。如果构建失败,它们不等价。如果成功,最终的排列将是m-排列和s-排列的组合(虽然你可能根本不需要构建它,只需要知道它存在即可)。
这里有一个简化的方法,基于我之前的回答。
这个想法是,如果两个表达式在排列下是等价的,那么把一个表达式变成另一个表达式的排列,必须把第一个字符串中第i个符号(按照第一次出现的顺序)映射到第二个字符串中第i个符号(同样按照第一次出现的顺序)。这个原则可以用来构造一个排列,把它应用到第一个字符串上,然后检查它和第二个字符串是否相等。如果相等,那它们就是等价的;如果不相等,那就不是。
下面是一个可能的实现:
import re
# Unique-ify list, preserving order
def uniquify(l):
return reduce(lambda s, e: s + ([] if e in s else [e]), l, [])
# Replace all keys in replacements with corresponding values in str
def replace_all(str, replacements):
for old, new in replacements.iteritems():
str = str.replace(old, new)
return str
class Expression:
units = ["m", "s"]
def __init__(self, exp):
self.exp = exp
# Returns a list of symbols in the expression that are preceded
# by the given unit, ordered by first appearance. Assumes the
# symbol and unit are separated by a space. For example:
# Expression("A m * B s / (A m + C m)").symbols_for_unit("m")
# returns ['A', 'C']
def symbols_for_unit(self, unit):
sym_re = re.compile("(.) %s" % unit)
symbols = sym_re.findall(self.exp)
return uniquify(symbols)
# Returns a string with all symbols that have units other than
# unit "muted", that is replaced with the empty string. Example:
# Expression("A m * B s / (A m + C m)").mute_symbols_for_other_units("m")
# returns "A m * s / (A m + C m)"
def mute_symbols_for_other_units(self, unit):
other_units = "".join(set(self.units) - set(unit))
return re.sub("(.) ([%s])" % "".join(other_units), " \g<2>", self.exp)
# Returns a string with all symbols that have the given unit
# replaced with tokens of the form $0, $1, ..., by order of their
# first appearance in the string, and all other symbols muted.
# For example:
# Expression("A m * B s / (A m + C m)").canonical_form("m")
# returns "$0 m * s / ($0 m + $1 m)"
def canonical_form(self, unit):
symbols = self.symbols_for_unit(unit)
muted_self = self.mute_symbols_for_other_units(unit)
for i, sym in enumerate(symbols):
muted_self = muted_self.replace("%s %s" % (sym, unit), "$%s %s" % (i, unit))
return muted_self
# Define a permutation, represented as a dictionary, according to
# the following rule: replace $i with the ith distinct symbol
# occurring in the expression with the given unit. For example:
# Expression("C m * B s / (C m + A m)").permutation("m")
# returns {'$0':'C', '$1':'A'}
def permutation(self, unit):
enum = enumerate(self.symbols_for_unit(unit))
return dict(("$%s" % i, sym) for i, sym in enum)
# Return a string produced from the expression by first converting it
# into canonical form, and then performing the replacements defined
# by the given permutation. For example:
# Expression("A m * B s / (A m + C m)").permute("m", {"$0":"C", "$1":"A"})
# returns "C m * s / (C m + A m)"
def permute(self, unit, permutation):
new_exp = self.canonical_form(unit)
return replace_all(new_exp, permutation)
# Test for equality under permutation and muting of all other symbols
# than the unit provided.
def eq_under_permutation(self, unit, other_exp):
muted_self = self.mute_symbols_for_other_units(unit)
other_permuted_str = other_exp.permute(unit, self.permutation(unit))
return muted_self == other_permuted_str
# Test for equality under permutation. This is done for each of
# the possible units using eq_under_permutation
def __eq__(self, other):
return all([self.eq_under_permutation(unit, other) for unit in self.units])
e1 = Expression("A m * B s / (A m + C m)")
e2 = Expression("C m * B s / (C m + A m)")
e3 = Expression("A s * B s / (A m + C m)")
f1 = Expression("A s * (B s + D s) / (A m + C m)")
f2 = Expression("A s * (D s + B s) / (C m + A m)")
f3 = Expression("D s")
print "e1 == e2: ", e1 == e2 # True
print "e1 == e3: ", e1 == e3 # False
print "e2 == e3: ", e2 == e3 # False
print "f1 == f2: ", f1 == f2 # True
print "f1 == f3: ", f1 == f3 # False
正如你所指出的,这个方法检查字符串在排列下的等价性,而不考虑数学上的等价性,但这已经是解决问题的一半了。如果你有一个标准形式的数学表达式,你可以在两个标准形式的表达式上使用这个方法。也许sympy中的某个简化功能可以帮到你。