How to compute the log of 1 minus the exponential of a small number in Python
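I'm doing a probability calculation. I have many very, very small numbers that I want to subtract from 1, and I want to do this accurately. I can compute the logs of these small numbers accurately. My strategy so far is this (using numpy):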
Given an array x holding the logs of the small numbers, compute:
y = numpy.logaddexp.reduce(x)
Now I want to compute something like 1-exp(y), or better still log(1-exp(y)), but I'm not sure how to do it without losing all my precision.
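A quick illustration of why the naive formula breaks down (an added standard-library demo, not part of the original question; the value of y is an arbitrary example): adjacent doubles just below 1.0 are about 1.1e-16 apart, so exp(y) for a tiny negative y rounds to exactly 1.0 and 1 - exp(y) collapses to 0.

```python
import math
import sys

eps = sys.float_info.epsilon   # machine epsilon of an IEEE 754 double, 2**-52
y = -1e-17                     # an arbitrary tiny negative log-probability

print(eps)
print(1.0 - eps / 8 == 1.0)    # True: 1 - eps/8 is not representable and rounds back to 1.0
print(math.exp(y) == 1.0)      # True: exp(y) rounds to exactly 1.0
print(1 - math.exp(y))         # 0.0, so log(1 - exp(y)) would raise a math domain error
print(-math.expm1(y))          # about 1e-17, computed without passing through 1.0
```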
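In fact, even the logaddexp function runs into precision problems. The values in the vector x range from -2 to -800, or even lower. The vector y mentioned above ends up with a whole run of values around 1e-16, which is the eps of the data type (float64 machine epsilon is about 2.2e-16). So, for example, accurately computed data might look like this: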
In [358]: x
Out[358]:
[-5.2194676211172837,
-3.9050377656308362,
-3.1619783292449615,
-2.71289594096134,
-2.4488395891021639,
-2.3129210706827568,
-2.2709987626652346,
-2.3007776073511259,
-2.3868404149802434,
-2.5180718876609163,
-2.68619816583087,
-2.8849022632856958,
-3.1092603032627686,
-3.3553673369747834,
-3.6200806272462351,
-3.9008385919463073,
-4.1955300857178379,
-4.5023981074719899,
-4.8199676154248081,
-5.1469905756384904,
-5.4824035553480428,
-5.8252945959126876,
-6.174877049340779,
-6.5304687083067563,
-6.8914750074202473,
-7.25737538919104,
-7.6277121540338797,
-8.0020812775389558,
-8.3801247986220773,
-8.7615244716292437,
-9.1459964426584435,
-9.5332867613176404,
-9.9231675781398394,
-10.315433907978701,
-10.709900863130784,
-11.106401278287066,
-11.50478366390567,
-11.904910436107656,
-12.30665638039909,
-12.709907313918777,
-13.114558916892051,
-13.52051570882999,
-13.927690148982549,
-14.336001843810081,
-14.745376846921289,
-15.155747039147968,
-15.567049578271309,
-15.979226409456359,
-16.39222382873956,
-16.805992092998878,
-17.22048507074976,
-17.63565992888303,
-18.051476851117201,
-18.467898784496384,
-18.884891210740903,
-19.302421939667397,
-19.720460922243518,
-20.138980081145718,
-20.557953156947775,
-20.977355568292495,
-21.397164284594595,
-21.817357709992422,
-22.237915577412224,
-22.658818851739369,
-23.080049641202237,
-23.501591116172762,
-23.923427434676114,
-24.345543673975158,
-24.767925767665417,
-25.190560447772668,
-25.61343519140047,
-26.036538171518259,
-26.459858211524278,
-26.883384743252066,
-27.307107768123842,
-27.731017821180984,
-28.155105937748402,
-28.579363622513654,
-29.003782820820732,
-29.428355891997484,
-29.853075584553352,
-30.27793501309668,
-30.702927636836705,
-31.128047239545907,
-31.553287910869187,
-31.978644028878307,
-32.404110243774596,
-32.82968146265631,
-33.255352835270173,
-33.681119740674262,
-34.106977774747804,
-34.532922738484046,
-34.958950627012712,
-35.385057619298891,
-35.811240068471022,
-36.237494492735493,
-36.663817566835519,
-37.090206114019054,
-37.516657098479527,
-37.943167618239784,
-38.369734898447348,
-38.796356285056333,
-39.223029238868548,
-39.64975132991276,
-40.076520232137909,
-40.5033337184027,
-40.930189655741344,
-41.357086000888444,
-41.784020796047173,
-42.210992164885965,
-42.637998308748706,
-43.065037503066776,
-43.492108093959985,
-43.919208495015312,
-44.346337184233221,
-44.773492701130749,
-45.200673643993753,
-45.627878667267964,
-46.055106479082156,
-46.482355838895614,
-46.909625555262096,
-47.336914483704675,
-47.764221524695017,
-48.191545621730768,
-48.618885759506213,
-49.04624096217151,
-49.473610291673936,
-49.900992846179292,
-50.328387758566748,
-50.755794194994508,
-51.183211353532613,
-51.610638462858901,
-52.0380747810147,
-52.46551959421754,
-52.892972215728378,
-53.320431984769073,
-53.747898265489198,
-54.175370445978274,
-54.602847937323247,
-55.030330172705362,
-55.457816606538813,
-55.885306713645889,
-56.312799988467418,
-56.740295944308855,
-57.167794112617116,
-57.59529404228897,
-58.02279529900909,
-58.450297464615232,
-58.877800136490578,
-59.305302926981085,
-59.732805462838542,
-60.160307384683506,
-60.587808346493375,
-61.015308015110463,
-61.442806069768608,
-61.87030220164138,
-62.297796113406662,
-62.725287518829532,
-63.15277614236129,
-63.580261718755196,
-64.007743992695964,
-64.435222718445743,
-64.862697659501919,
-65.290168588270035,
-65.717635285748088,
-66.14509754122389,
-66.572555151982783,
-67.000007923029216,
-67.427455666815376,
-67.854898202982099,
-68.282335358110231,
-68.709766965479957,
-69.137192864839108,
-69.564612902180784,
-69.992026929530198,
-70.419434804735829,
-70.8468363912732,
-71.274231558051156,
-71.701620179229167,
-72.129002134037705,
-72.556377306608397,
-72.983745585807242,
-73.411106865077045,
-73.838461042282461,
-74.265808019561746,
-74.693147703185559,
-75.120480003416901,
-75.547804834380145,
-75.97512211393132,
-76.402431763534764,
-76.829733708143749,
-77.257027876085431,
-77.684314198948414,
-78.111592611476681,
-78.538863051464546,
-78.966125459656723,
-79.393379779652037,
-79.820625957809625,
-80.24786394315754,
-80.675093687306912,
-81.102315144366912]
Then I try to compute the log of the sum of the exponentials:
In [359]: np.logaddexp.accumulate(x)
Out[359]:
array([ -5.21946762e+00, -3.66710221e+00, -2.68983273e+00,
-2.00815067e+00, -1.51126604e+00, -1.14067818e+00,
-8.60829425e-01, -6.48188808e-01, -4.86276416e-01,
-3.63085873e-01, -2.69624488e-01, -1.99028599e-01,
-1.45996863e-01, -1.06408884e-01, -7.70565672e-02,
-5.54467248e-02, -3.96506186e-02, -2.81859503e-02,
-1.99225261e-02, -1.40061296e-02, -9.79701394e-03,
-6.82045164e-03, -4.72733966e-03, -3.26317960e-03,
-2.24396350e-03, -1.53767347e-03, -1.05026994e-03,
-7.15209142e-04, -4.85690052e-04, -3.28980607e-04,
-2.22305294e-04, -1.49890553e-04, -1.00858788e-04,
-6.77380054e-05, -4.54139175e-05, -3.03974537e-05,
-2.03154477e-05, -1.35581905e-05, -9.03659252e-06,
-6.01552344e-06, -3.99984336e-06, -2.65671945e-06,
-1.76283376e-06, -1.16860435e-06, -7.73997496e-07,
-5.12213574e-07, -3.38706792e-07, -2.23809375e-07,
-1.47785898e-07, -9.75226648e-08, -6.43149957e-08,
-4.23904687e-08, -2.79246430e-08, -1.83858489e-08,
-1.20995365e-08, -7.95892319e-09, -5.23300609e-09,
-3.43929670e-09, -2.25953475e-09, -1.48391255e-09,
-9.74194956e-10, -6.39351406e-10, -4.19466218e-10,
-2.75121795e-10, -1.80397409e-10, -1.18254918e-10,
-7.74993004e-11, -5.07775611e-11, -3.32619009e-11,
-2.17835737e-11, -1.42634249e-11, -9.33764336e-12,
-6.11190167e-12, -3.99989955e-12, -2.61737204e-12,
-1.71253165e-12, -1.12043465e-12, -7.33052079e-13,
-4.79645919e-13, -3.13905885e-13, -2.05519681e-13,
-1.34650094e-13, -8.83173582e-14, -5.80300378e-14,
-3.82338678e-14, -2.52963381e-14, -1.68421145e-14,
-1.13181549e-14, -7.70918073e-15, -5.35155125e-15,
-3.81152630e-15, -2.80565548e-15, -2.14872312e-15,
-1.71971577e-15, -1.43957518e-15, -1.25665732e-15,
-1.13722927e-15, -1.05925916e-15, -1.00835857e-15,
-9.75131524e-16, -9.53442707e-16, -9.39286186e-16,
-9.30046550e-16, -9.24016349e-16, -9.20080954e-16,
-9.17512772e-16, -9.15836886e-16, -9.14743318e-16,
-9.14029759e-16, -9.13564174e-16, -9.13260398e-16,
-9.13062204e-16, -9.12932898e-16, -9.12848539e-16,
-9.12793505e-16, -9.12757603e-16, -9.12734183e-16,
-9.12718905e-16, -9.12708939e-16, -9.12702438e-16,
-9.12698198e-16, -9.12695432e-16, -9.12693627e-16,
-9.12692451e-16, -9.12691683e-16, -9.12691183e-16,
-9.12690856e-16, -9.12690643e-16, -9.12690504e-16,
-9.12690414e-16, -9.12690355e-16, -9.12690316e-16,
-9.12690291e-16, -9.12690275e-16, -9.12690264e-16,
-9.12690257e-16, -9.12690252e-16, -9.12690249e-16,
-9.12690248e-16, -9.12690246e-16, -9.12690245e-16,
-9.12690245e-16, -9.12690245e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16,
-9.12690244e-16, -9.12690244e-16, -9.12690244e-16])
This ultimately gives:
In [360]: np.logaddexp.reduce(x)
Out[360]: -9.1269024387687033e-16
So my precision is gone. Is there any way around this?
6 answers
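I'd suggest replacing exp() near 0 and log() near 1 with their Taylor series. That way you avoid losing precision by working with large numbers (did I really just call 1 a large number? :^). Use the Lagrange form of the remainder, or a conservative safety margin, to decide at what point the truncation error would exceed your precision.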
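A minimal sketch of the Taylor-series idea for the 1 - exp(a) part, assuming a is a small negative number (say |a| < 1e-3). The helper name neg_expm1_taylor and the default of 8 terms are hypothetical choices for illustration; the result is compared against math.expm1 as a reference.

```python
import math

def neg_expm1_taylor(a, terms=8):
    """Hypothetical helper: approximate 1 - exp(a) by a truncated Taylor series.

    exp(a) - 1 = a + a**2/2! + a**3/3! + ..., so 1 - exp(a) is minus that sum.
    Intended only for small |a|, where the terms shrink quickly and the
    subtraction 1 - exp(a) would otherwise cancel catastrophically.
    """
    total = 0.0
    term = 1.0
    for k in range(1, terms + 1):
        term *= a / k          # term is now a**k / k!
        total += term
    return -total

for a in (-1e-20, -1e-8, -1e-3):
    naive = 1 - math.exp(a)
    series = neg_expm1_taylor(a)
    print(a, naive, series, -math.expm1(a))
```

For a <= 0 the Lagrange remainder is exp(xi) * a**(terms+1) / (terms+1)! with a <= xi <= 0, so exp(xi) <= 1 and |a|**(terms+1) / (terms+1)! bounds the truncation error.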
Update:
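In Python 2.7, math.expm1 (exp(x)-1) and math.log1p (log(1+x)) handle this for you, as long as your platform's C library is precise enough (usually double precision). If it isn't, you'll have to use special-purpose math software (the x86 FPU can do extended-precision arithmetic).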
I don't know much Python; I do most of my work in Java. Still, I think it would be better to apply log-sum-exp to all the values at once rather than pairwise with numpy.logaddexp.accumulate.
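A quick Google search turned up a candidate in the Python libraries: scipy.misc.logsumexp (newer SciPy versions provide it as scipy.special.logsumexp).
It isn't hard to implement yourself, though: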
logsumexp(probs) := max(probs) + log(sum[i](exp(probs[i] - max(probs))))
Roughly like this:
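import math

def logsumexp(probs):
    # Find the largest log-probability to use as a shift
    max_value = -math.inf
    for x in probs:
        if x > max_value:
            max_value = x
    # Sum the shifted exponentials; the largest term is exp(0) = 1, so no overflow
    exp_sum = 0.0
    for x in probs:
        exp_sum += math.exp(x - max_value)
    # Undo the shift: log(sum(exp(x))) = max_value + log(sum(exp(x - max_value)))
    return max_value + math.log(exp_sum)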
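The single value returned, say p, is the log of the sum of all the probabilities probs. Note that if there is a large gap between the largest input and the smaller ones, the small values are effectively ignored, because their contribution is negligible next to the largest one; this should be fine for most applications.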
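If there are many small values, you can use more elaborate strategies so they still count. For example, take a sum like 0.5 + 1E-7 + 1E-7 + ..., with the 1E-7 term repeated a million times so the small terms add up to 0.1. You could split the single sum into several partial sums, first adding values of roughly the same magnitude together and then combining the results. Or you could shift by an intermediate pivot value instead of the maximum, but make sure the largest values are not too far above the pivot, or exp(probs[i] - pivot) may overflow.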
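One alternative to grouping by magnitude by hand is to let math.fsum do the summation; this swaps in a different technique from the partial-sum scheme described above. math.fsum tracks exact partial sums, so many small terms are not lost to rounding. A minimal sketch (the name logsumexp_fsum is hypothetical), checked on the 0.5 + a million times 1E-7 example:

```python
import math

def logsumexp_fsum(log_values):
    """Hypothetical helper: log(sum(exp(v) for v in log_values)).

    Shifts by the maximum for overflow safety and sums the scaled terms
    with math.fsum, which avoids accumulating rounding error.
    """
    log_values = list(log_values)
    m = max(log_values)
    if m == -math.inf:          # every input is log(0)
        return -math.inf
    return m + math.log(math.fsum(math.exp(v - m) for v in log_values))

# The example from the text: 0.5 plus a million terms of 1e-7 (which total 0.1)
logs = [math.log(0.5)] + [math.log(1e-7)] * 1_000_000
print(math.exp(logsumexp_fsum(logs)))   # close to 0.6
```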
Once you have that, you still need to compute log(1-exp(p)).
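For this, I found a paper that describes a method for minimizing the loss of precision, using functions available in the math libraries of most programming languages.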
Maechler M, Accurately Computing log(1 − exp(− |a|)), Assessed by the Rmpfr package, 2012
The key is to use one of two methods, depending on the value of the input a.
Definitions:
log1mexp(a) := log(1-exp(a))  # function that we seek to implement
log1p(a) := log(1+a)  # function provided by common math libraries
expm1(a) := exp(a) - 1  # function provided by common math libraries
An obvious way to implement log1mexp uses log1p:
log1mexp(a) := log1p(-exp(a))
With expm1 you can instead write:
log1mexp(a) := log(-expm1(a))
Use the log1p form when a < log(.5) and the expm1 form when a >= log(.5):
log1mexp(a) := (a < log(.5)) ? log1p(-exp(a)) : log(-expm1(a)).
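The two-branch rule above translates directly into Python's math module. A minimal sketch, assuming a scalar input a < 0 (raising ValueError otherwise is a choice made here, not something the paper prescribes):

```python
import math

LOG_HALF = math.log(0.5)   # cutoff between the two formulas, equal to -log(2)

def log1mexp(a):
    """Compute log(1 - exp(a)) for a < 0 using the two-branch rule."""
    if a >= 0:
        raise ValueError("log1mexp is only defined for a < 0")
    if a < LOG_HALF:
        # exp(a) < 0.5: log1p stays accurate when exp(a) is tiny
        # and log(1 - exp(a)) is close to 0
        return math.log1p(-math.exp(a))
    # exp(a) >= 0.5: 1 - exp(a) may be tiny, so take it from expm1 directly
    return math.log(-math.expm1(a))

for a in (-1e-20, -9.1269024387687033e-16, -1.0, -50.0):
    print(a, log1mexp(a))
```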
See the external link for more details.
Python 2.7 added math.expm1(), which is designed for exactly this situation:
>>> from math import exp, expm1
>>> exp(1e-5) - 1 # gives result accurate to 11 places
1.0000050000069649e-05
>>> expm1(1e-5) # result accurate to full precision
1.0000050000166668e-05
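The math module also has the companion function math.log1p, which handles the log side in the same way. A small sketch, with x = 1e-10 chosen arbitrarily:

```python
from math import log, log1p

x = 1e-10
print(log(1 + x))   # 1 + x is rounded first, so only about 7 significant digits survive
print(log1p(x))     # accurate to full precision
```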
There is also math.fsum(), which computes a sum without losing precision:
>>> sum([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1])
0.9999999999999999
>>> fsum([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1])
1.0
Finally, if none of this works, you can use the decimal module, which supports arbitrary-precision arithmetic:
>>> from decimal import *
>>> getcontext().prec = 200
>>> (1 - 1 / Decimal(7000000)).ln()
Decimal('-1.4285715306122546161332083855139723669559469615692284955124609122046580004888309867906750714454869716398919778588515625689415322789136206397998627088895481989036005482451668027002380442299229191323673E-7')
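The decimal module is also handy for checking a double-precision result against a high-precision reference. A sketch, assuming 50 significant digits is enough for the reference and using a = -1e-20 as an arbitrary test input:

```python
import math
from decimal import Decimal, localcontext

a = -1e-20

# Reference value with 50 significant digits; Decimal(a) converts the float exactly
with localcontext() as ctx:
    ctx.prec = 50
    reference = (1 - Decimal(a).exp()).ln()

fast = math.log(-math.expm1(a))   # double-precision result
print(reference)
print(fast, float(reference))
```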