避免在此numpy logsumexp计算中产生无限值
接着之前提到的关于“值错误:真值模糊”的问题,我正在编辑这个logsumexp函数,具体代码可以在这里找到:https://github.com/scipy/scipy/blob/v0.14.0/scipy/misc/common.py#L18
这样做的原因有两个:
- 我想自己选择最大值,而不仅仅是数组中的最大值。
- 我想加一个条件,确保在从每个元素中减去最大值后,差值不会低于某个特定的阈值。
这是我最终的代码。代码本身没有问题——但有时候它仍然会返回无穷大!
def mylogsumexp(self, a, is_class, maxaj=None, axis=None, b=None):
threshold = -sys.float_info.max
a = asarray(a)
if axis is None:
a = a.ravel()
else:
a = rollaxis(a, axis)
if is_class == 1:
a_max = a.max(axis=0)
else:
a_max = maxaj
if b is not None:
b = asarray(b)
if axis is None:
b = b.ravel()
else:
b = rollaxis(b, axis)
#out = log(sum(b * exp(threshold if a - a_max < threshold else a - a_max), axis=0))
out = np.log(np.sum(b * np.exp( np.minimum(a - a_max, threshold)), axis=0))
else:
out = np.log(np.sum(np.exp( np.minimum(a - a_max, threshold)), axis=0))
out += a_max
1 个回答
1
你可以使用 np.clip
来限制一个数组的最大值和最小值:
>>> arr = np.arange(10)
>>> np.clip(arr, 3, 7)
array([3, 3, 3, 3, 4, 5, 6, 7, 7, 7])
在这个例子中,所有大于7的值都会被限制在7;所有小于3的值都会被设置为3。
如果我理解你的代码没错,你可能想把
out = np.log(np.sum(b * np.exp( np.minimum(a - a_max, threshold)), axis=0))
替换成
out = np.log(np.sum(b * np.exp( np.clip(a - a_max, threshold, maximum)), axis=0))
其中 maximum
是你想要的最大值。