Numpy max函数在使用包含NaN的Decimal值时失败

4 投票
2 回答
1582 浏览
提问于 2025-04-18 18:43

当我们使用浮点数(float)时,一切都很顺利。

>>> import numpy as np
>>> np.max(1.2, np.nan)
>>> nan

但是,当我们使用十进制数(Decimal)时...

>>> import numpy as np
>>> import decimal as d 
>>> np.max([d.Decimal('1.2'), d.Decimal('NaN')])
>>> InvalidOperation: comparison involving NaN

有没有办法让带有NaN(不是一个数字)的十进制数好好地工作呢?

注意:

  • Python 2.7
  • Numpy 1.6.2

2 个回答

1

你可以把这个列表转换成一个NumPy数组,数据类型设置为float(浮点数)。这样,所有的NumPy函数都能正常使用了:

import numpy as np
import decimal as d

print np.max(np.array([0, 1, d.Decimal('nan')], dtype='float'))
print np.nanmax(np.array([0, 1, d.Decimal('nan')], dtype='float'))

输出结果:

nan
1.0
4

嗯……如果里面有至少一个 NaN(不是一个数字),那么结果就是 NaN

可以把它放在一个函数里:

def my_max(arr):
    try:
        return np.max(arr)
    except d.InvalidOperation:
        return d.Decimal('NaN')

不过,这样看起来不是很酷……


还有一种替代方法……也许……因为 Decimal 可以“解开”一些异常,这样就能返回一个值,而不是抛出一个异常:

# change globally
>>> d.getcontext().traps[d.InvalidOperation] = 0
>>> np.max([d.Decimal('1.2'), d.Decimal('NaN')])
Decimal('NaN')


# use a context manager to change locally:
with d.localcontext() as ctx:
    ctx.traps[d.InvalidOperation] = 0
    np.max([d.Decimal('1.2'), d.Decimal('NaN')])

撰写回答