Optimize numbers so that their sum is zero after rounding
In Python 3, suppose I have a dictionary such as:
d = {'C1': 0.68759, 'C2': -0.21432, 'H1': 0.49062, 'H2': -0.13267, 'H3': 0.08092, 'O1': -0.8604}
I also have an equation:
n1*d['C1'] + n2*d['H2'] + n3*d['C2'] + n4*d['H3'] + n5*d['O1'] + n6*d['H1'] = small_number
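I want this equation to evaluate to exactly zero when each value d[key] is rounded to 5 decimal places. Here n1, n2, n3, ... are positive integers. I also want to limit how much each dictionary value may change, for example to at most ±0.0005. In this specific example: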
n1, n2, n3, n4, n5, n6 = 2, 2, 2, 4, 4, 2
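To make the target concrete, here is a small sketch that computes the leftover (small_number) for this example. It assumes, as the first answer's code does, that n is paired with the keys of d in insertion order (C1, C2, H1, H2, H3, O1), and it works in integer units of 1E-05 so that the rounded sum is exact rather than subject to floating-point noise:

```python
# Values and coefficients from the question.
# Assumption: n is paired with the keys of d in insertion order.
d = {'C1': 0.68759, 'C2': -0.21432, 'H1': 0.49062, 'H2': -0.13267, 'H3': 0.08092, 'O1': -0.8604}
n = (2, 2, 2, 4, 4, 2)

DIGITS = 5
scale = 10 ** DIGITS

# Express each rounded value as an integer number of 1e-5 steps,
# so the linear combination can be summed without rounding error.
units = [round(v * scale) for v in d.values()]
residual_units = sum(ni * ui for ni, ui in zip(n, units))

print(residual_units)          # -2
print(residual_units / scale)  # -2e-05
```

The goal is then to shift some values by whole steps of 1E-05 until residual_units becomes 0.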
Is there a simple and general way to do this?
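By "general" I mean that d may contain more key-value pairs and the scalars n1, n2, ... may be different (but are always positive integers). The length of d always equals the number of scalars n. It is reasonable to assume that small_number is around 1E-5 in every case. If necessary, I can also round to one digit more or fewer than 5.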
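I considered subtracting the nonzero part of the equation from the value of one particular key in d, but that runs into trouble when the sum is 1E-5 and, as in my example, none of the scalars equals 1: the required correction small_number / n_i is then smaller than the last rounded digit. I could also use scipy.optimize.minimize and somehow define an objective function of d and the scalars n, but I am not sure how to do that correctly.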
The scalars n must not change.
2 Answers
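I found a solution that works for me. If the first digit of the leftover small_number equals one of the coefficients n, I can subtract small_number / n from the d[key] that corresponds to that coefficient.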
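If the first digit of small_number does not equal any coefficient n, I take the smallest coefficient n, divide small_number by it, subtract the result from the corresponding dictionary value, and round to the required number of digits. It is a bit convoluted, but the code that works for me is: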
import sys
import numpy as np
# Atom types and charges of each atom
d = {'C1': 0.68759, 'C2': -0.21432, 'H1': 0.49062, 'H2': -0.13267, 'H3': 0.08092, 'O1': -0.8604}
# Number of each atom occurring in the molecule, in the same order as the keys of d
n = (2, 2, 2, 4, 4, 2)
# Calculate the leftover charge, rounded to 5 decimal places
def calc_sum(n, d):
    repeats = np.repeat(np.array(list(d.values())), n)
    return round(np.sum(repeats), 5)
# Get the first significant digit of the leftover charge
def get_divisor(final_sum):
    final_sum = str(final_sum)
    final_sum = final_sum.replace("-", '')
    final_sum = final_sum.replace("0", '')
    final_sum = final_sum.replace(".", '')
    return int(final_sum[0])
# Get the power of a number in scientific notation; assumed to be less than 1E-04
def get_power(number):
    number = str(number)
    return int(number[-2:].replace('0', ''))
final_sum = calc_sum(n, d)
# Quit if the leftover charge is too large or too small
if abs(final_sum) > 1E-04 or abs(final_sum) < 1E-05:
    sys.exit("Leftover charge is outside the range 1E-05 to 1E-04")
divisor = get_divisor(final_sum)
# If the divisor equals the count of some atom type, subtract leftover / divisor
# from the first atom type that has that count
if divisor in n:
    divisor_index = n.index(divisor)
    divisor_key = list(d)[divisor_index]
    new_value = round(d[divisor_key] - final_sum/divisor, 5)
    d[divisor_key] = new_value
# If the divisor is not equal to any item in n, divide the leftover charge
# by the smallest number in n, subtract this value from the d[key] that corresponds to the
# smallest number in n, and round to the appropriate digits so that the linear combination is zero
else:
    min_value_n = min(n)
    min_value_index = n.index(min_value_n)
    to_subtract = final_sum/min_value_n
    round_to = get_power(to_subtract)
    divisor_key = list(d)[min_value_index]
    new_value = round(d[divisor_key] - to_subtract, round_to)
    d[divisor_key] = new_value
print("Final rounded linear combination: ", calc_sum(n, d))  # Will be zero if everything is correct
print(d)
I believe this works for any dictionary d and coefficients n, as long as len(n) == len(d) and the magnitude of the linear combination is between 1E-05 and 1E-04.
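As an aside that is not part of the answer above: get_divisor and get_power depend on how Python formats floats. get_power reads the last two characters of the scientific-notation string, so it breaks for exponents of 10 or more, and numbers of 1E-04 or larger are not printed in scientific notation at all. A minimal arithmetic sketch follows; the names leftover_units and decimal_places are hypothetical, and it assumes final_sum has already been rounded to 5 digits by calc_sum. Note that leftover_units returns the leftover as a whole number of 1E-05 steps, which equals the "first digit" for leftovers of 1 to 9 steps.

```python
import math


def leftover_units(final_sum: float, digits: int = 5) -> int:
    """Hypothetical helper: size of the leftover as a whole number of
    10**-digits steps, e.g. -2e-05 -> 2."""
    return abs(round(final_sum * 10 ** digits))


def decimal_places(number: float) -> int:
    """Hypothetical helper: decimal places needed to keep the leading
    significant digit, e.g. 1e-05 -> 5 and 6.67e-06 -> 6."""
    return -math.floor(math.log10(abs(number)))


print(leftover_units(-2e-05))      # 2
print(decimal_places(-2e-05 / 2))  # 5
print(decimal_places(-2e-05 / 3))  # 6
```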
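What counts as a "good solution"? This is not as simple as it looks. A good solution might mean a low error, small coefficients, or a weighted combination of both. This example shows how scipy can minimize the error while treating the coefficients as decision variables that carry no cost of their own.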
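Written out, with s standing for small, e for the error variable and d_i for the dictionary values in insertion order, the mixed-integer linear program that solve sets up is:

```latex
\begin{aligned}
\min_{n \in \mathbb{Z}^N,\; e \in \mathbb{R}} \quad & e \\
\text{subject to} \quad & e \ge \sum_{i=1}^{N} n_i d_i - s, \\
& e \ge s - \sum_{i=1}^{N} n_i d_i, \\
& 1 \le n_i \le n_{\max}, \qquad i = 1, \dots, N.
\end{aligned}
```

The two inequalities together force e \ge |\sum_i n_i d_i - s|, so at the optimum e equals the absolute error and minimizing e minimizes that error.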
import numpy as np
from scipy.optimize import milp, Bounds, LinearConstraint
def solve(
    d: dict[str, float],
    small: float,
    nmax: int = 20,
) -> tuple[np.ndarray, float]:
    N = len(d)
    # Variables: n (positive integers), absolute error (continuous)
    # Minimize absolute error.
    c = np.concatenate((np.zeros(N), (1,)))
    # n are integral, error is continuous
    integrality = np.ones(N+1, dtype=np.uint8)
    integrality[-1] = 0
    # 1 <= n <= some reasonable max
    bounds = Bounds(
        lb=np.concatenate((np.ones(N), (-np.inf,))),
        ub=np.concatenate((np.full(shape=N, fill_value=nmax), (+np.inf,))),
    )
    d_array = np.array(tuple(d.values()))
    # error >= n.d - small  <=>  -n.d + error >= -small
    # error >= small - n.d  <=>  +n.d + error >= +small
    constraints = LinearConstraint(
        A=np.block([
            [-d_array, np.ones(1)],
            [+d_array, np.ones(1)],
        ]),
        lb=[-small, small], ub=[np.inf, np.inf],
    )
    result = milp(
        c=c, integrality=integrality, bounds=bounds, constraints=constraints,
    )
    if not result.success:
        raise ValueError(result.message)
    n, (error,) = np.split(result.x, (-1,))
    return n, error
def demo() -> None:
    small = 1e-5
    d = {
        'C1': 0.68759, 'C2': -0.21432,
        'H1': 0.49062, 'H2': -0.13267,
        'H3': 0.08092, 'O1': -0.86040,
    }
    n, error = solve(d=d, small=small)
    print(
        ' + '.join(
            f'{xn:.0f}*{xd:.5f}'
            for xn, xd in zip(n, d.values())
        ),
        f'= {small} + {error}'
    )
if __name__ == '__main__':
    demo()
Output:
1*0.68759 + 3*-0.21432 + 12*0.49062 + 18*-0.13267 + 20*0.08092 + 6*-0.86040 = 1e-05 + -0.0
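The milp program above treats n as free, while the question keeps n fixed and lets each value of d move by at most ±0.0005. As a hedged sketch that is not part of either answer, the same scipy.optimize.milp solver (SciPy 1.9 or later) can be pointed at that formulation: the unknowns become integer shifts delta_i in steps of 1E-05 with |delta_i| <= 50 (that is, ±0.0005), the constraint is sum(n_i * delta_i) = -leftover, and auxiliary variables t_i >= |delta_i| turn "move as few steps as possible" into a linear objective. The function name adjust and its parameters digits and max_units are hypothetical, and n is assumed to be paired with d in insertion order.

```python
import numpy as np
from scipy.optimize import Bounds, LinearConstraint, milp


def adjust(d: dict[str, float], n: list[int], digits: int = 5,
           max_units: int = 50) -> dict[str, float]:
    """Hypothetical helper: shift each value of d by at most max_units steps
    of 10**-digits so that sum(n_i * d_i) is exactly zero at that precision,
    moving as few steps in total as possible. n is kept fixed."""
    scale = 10 ** digits
    N = len(d)
    # Current values as integer multiples of 10**-digits
    units = np.array([round(v * scale) for v in d.values()])
    coef = np.asarray(n, dtype=float)
    # The shifts must cancel the current leftover exactly
    target = -float(coef @ units)

    # Decision vector x = [delta_1..delta_N, t_1..t_N]; minimise sum(t)
    c = np.concatenate((np.zeros(N), np.ones(N)))
    integrality = np.concatenate((np.ones(N, dtype=np.uint8),
                                  np.zeros(N, dtype=np.uint8)))
    bounds = Bounds(
        lb=np.concatenate((np.full(N, -max_units), np.zeros(N))),
        ub=np.full(2 * N, max_units),
    )
    eye = np.eye(N)
    A = np.vstack((
        np.concatenate((coef, np.zeros(N))),  # sum(n_i * delta_i) == target
        np.hstack((-eye, eye)),               # t_i - delta_i >= 0
        np.hstack((eye, eye)),                # t_i + delta_i >= 0
    ))
    lb = np.concatenate(([target], np.zeros(2 * N)))
    ub = np.concatenate(([target], np.full(2 * N, np.inf)))
    result = milp(c=c, integrality=integrality, bounds=bounds,
                  constraints=LinearConstraint(A, lb, ub))
    if not result.success:
        raise ValueError(result.message)
    delta = np.round(result.x[:N]).astype(int)
    return {key: (int(u) + int(dl)) / scale
            for key, u, dl in zip(d, units, delta)}


if __name__ == '__main__':
    d = {'C1': 0.68759, 'C2': -0.21432, 'H1': 0.49062,
         'H2': -0.13267, 'H3': 0.08092, 'O1': -0.8604}
    print(adjust(d, [2, 2, 2, 4, 4, 2]))
```

With the question's data, a single value whose coefficient is 2 moves by 1E-05; which one the solver picks among the equally cheap options is not guaranteed. If no solution exists within the bounds (for example, when the leftover is not a multiple of the greatest common divisor of n), milp reports failure and the function raises ValueError.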