我试图用以下代码实现soft-max(out_vec
是一个numpy
浮点向量):
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
但是,由于np.exp(out_vec)
,出现溢出错误。因此,我(手动)检查了np.exp()
的上限是多少,发现np.exp(709)
是一个数字,但np.exp(710)
被认为是np.inf
。因此,为了避免溢出错误,我修改了代码如下:
out_vec[out_vec > 709] = 709 #prevent np.exp overflow
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
现在,我得到一个不同的错误:
RuntimeWarning: invalid value encountered in greater out_vec[out_vec > 709] = 709
我加的那行怎么了?我查找了这个特定的错误,发现的只是人们关于如何忽略错误的建议。忽略这个错误对我没有帮助,因为每次我的代码遇到这个错误时,它都不会给出通常的结果。
在我的例子中,在比较之前调用此函数时没有显示警告(我比较了NaN值)
在IMO中,更好的方法是使用指数和的更稳定的数值实现。
您的问题是由
NaN
或Inf
数组中的元素引起的。您可以使用以下代码来避免此问题:或者可以使用以下代码将
NaN
值保留在数组中:相关问题 更多 >
编程相关推荐