极值的logit和逆logit函数

2024-04-26 21:07:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要logit和逆logit函数,以便logit(inv_logit(n)) == n。我用的是numpy,这是我所拥有的:

import numpy as np
def logit(p):
    return np.log(p) - np.log(1 - p)

def inv_logit(p):
    return np.exp(p) / (1 + np.exp(p))

以下是价值观:

print logit(inv_logit(2)) 
2.0 

print logit(inv_logit(10))
10.0 

print logit(inv_logit(20))
20.000000018 #well, pretty close

print logit(inv_logit(50))
Warning: divide by zero encountered in log
inf 

现在让我们测试负数

print logit(inv_logit(-10))
-10.0 
print logit(inv_logit(-20))
-20.0 
print logit(inv_logit(-200))
-200.0 
print logit(inv_logit(-500))
-500.0 
print logit(inv_logit(-2000))
Warning: divide by zero encountered in log
-inf 

所以我的问题是:实现这些功能的正确方法是什么,以便要求logit(inv_logit(n)) == n在尽可能宽的范围内(至少是[-1e4;1e4)对任何n都有效?

另外(我肯定这和第一个有关),为什么我的函数在负值的情况下比正值更稳定?


Tags: 函数innumpylogbyreturndefnp
3条回答

有一种方法可以实现这些函数,使它们在广泛的值范围内保持稳定,但它涉及到根据参数区分情况。

例如inv_logit函数。你的公式“np.exp(p)/(1+np.exp(p))”是正确的,但对于大p会溢出。如果你用np.exp(p)除以分子和分母,你会得到等价的表达式

1. / (1. + np.exp(-p))

不同之处在于,对于大正值p,这个函数不会溢出。但是对于大负值p,这个函数会溢出。因此,一个稳定的实现可以如下:

def inv_logit(p):
    if p > 0:
        return 1. / (1. + np.exp(-p))
    elif p <= 0:
        np.exp(p) / (1 + np.exp(p))
    else:
        raise ValueError

这是LIBLINEAR库(可能还有其他库)中使用的策略。

要么使用

一。 支持任意精度浮点运算的bigfloat包。

2。 SymPy符号数学包。我将举两个例子:

首先,bigfloat:

http://packages.python.org/bigfloat/

下面是一个简单的例子:

from bigfloat import *
def logit(p):
    with precision(100000):
        return log(p)- log(1 -BigFloat(p))

def inv_logit(p):
    with precision(100000):
        return exp(p) / (1 + exp(p))

int(round(logit(inv_logit(12422.0))))
# gives 12422
int(round(logit(inv_logit(-12422.0))))
# gives -12422

这真的很慢。你可以考虑重新构造你的问题,并做一些分析部分。像这样的案例在实际问题中很少见-我很好奇你在处理什么样的问题。

安装示例:

wget http://pypi.python.org/packages/source/b/bigfloat/bigfloat-0.3.0a2.tar.gz
tar xvzf bigfloat-0.3.0a2.tar.gz 
cd bigfloat-0.3.0a2
as root:
python setup.py install

关于为什么你的功能在负数的情况下会更好。考虑:

>>> float(inv_logit(-15))
3.059022269256247e-07

>>> float(inv_logit(15))
0.9999996940977731

在第一种情况下,浮点数很容易表示这个值。小数点移动,使前导零:0.0000。。。不需要存储。在第二种情况下,所有前导的0.999都需要存储,因此在稍后在logit()中执行1-p时,需要所有额外的精度才能获得精确的结果。

下面是符号数学方法(速度明显更快!)以下内容:

from sympy import *
def inv_logit(p):
    return exp(p) / (1 + exp(p))
def logit(p):
    return log(p)- log(1 -p)

x=Symbol('x')
expr=logit(inv_logit(x))
# expr is now:
# -log(1 - exp(x)/(1 + exp(x))) + log(exp(x)/(1 + exp(x)))
# rewrite it: (there are many other ways to do this. read the doc)
# you may want to make an expansion (of some suitable kind) instead.
expr=cancel(powsimp(expr)).expand()
# it is now 'x'

# just evaluate any expression like this:    
result=expr.subs(x,123.231)

# result is now an equation containing: 123.231
# to get the float: 
result.evalf()

Sympy在这里找到http://docs.sympy.org/。在ubuntu中它是通过synaptic找到的。

你正在接近IEEE 754双精度浮点的精度限制。如果需要更大的范围和更精确的域,则需要使用更高精度的数字和操作。

>>> 1 + np.exp(-37)
1.0
>>> 1 + decimal.Decimal(-37).exp()
Decimal('1.000000000000000085330476257')

相关问题 更多 >