Numpy max函数在使用包含NaN的Decimal值时失败
当我们使用浮点数(float)时,一切都很顺利。
>>> import numpy as np
>>> np.max(1.2, np.nan)
>>> nan
但是,当我们使用十进制数(Decimal)时...
>>> import numpy as np
>>> import decimal as d
>>> np.max([d.Decimal('1.2'), d.Decimal('NaN')])
>>> InvalidOperation: comparison involving NaN
有没有办法让带有NaN(不是一个数字)的十进制数好好地工作呢?
注意:
- Python 2.7
- Numpy 1.6.2
2 个回答
1
你可以把这个列表转换成一个NumPy数组,数据类型设置为float
(浮点数)。这样,所有的NumPy函数都能正常使用了:
import numpy as np
import decimal as d
print np.max(np.array([0, 1, d.Decimal('nan')], dtype='float'))
print np.nanmax(np.array([0, 1, d.Decimal('nan')], dtype='float'))
输出结果:
nan
1.0
4
嗯……如果里面有至少一个 NaN(不是一个数字),那么结果就是 NaN。
可以把它放在一个函数里:
def my_max(arr):
try:
return np.max(arr)
except d.InvalidOperation:
return d.Decimal('NaN')
不过,这样看起来不是很酷……
还有一种替代方法……也许……因为
Decimal
可以“解开”一些异常,这样就能返回一个值,而不是抛出一个异常:
# change globally
>>> d.getcontext().traps[d.InvalidOperation] = 0
>>> np.max([d.Decimal('1.2'), d.Decimal('NaN')])
Decimal('NaN')
# use a context manager to change locally:
with d.localcontext() as ctx:
ctx.traps[d.InvalidOperation] = 0
np.max([d.Decimal('1.2'), d.Decimal('NaN')])