sklearn的PLSRegression:“ValueError:数组不能包含infs或NaNs”

2024-05-15 02:49:37 发布

您现在位置:Python中文网/ 问答频道 /正文

使用^{}时:

import numpy as np
import sklearn.cross_decomposition

pls2 = sklearn.cross_decomposition.PLSRegression()
xx = np.random.random((5,5))
yy = np.zeros((5,5) ) 

yy[0,:] = [0,1,0,0,0]
yy[1,:] = [0,0,0,1,0]
yy[2,:] = [0,0,0,0,1]
#yy[3,:] = [1,0,0,0,0] # Uncommenting this line solves the issue

pls2.fit(xx, yy)

我得到:

C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:44: RuntimeWarning: invalid value encountered in divide
  x_weights = np.dot(X.T, y_score) / np.dot(y_score.T, y_score)
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:64: RuntimeWarning: invalid value encountered in less
  if np.dot(x_weights_diff.T, x_weights_diff) < tol or Y.shape[1] == 1:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:67: UserWarning: Maximum number of iterations reached
  warnings.warn('Maximum number of iterations reached')
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:297: RuntimeWarning: invalid value encountered in less
  if np.dot(x_scores.T, x_scores) < np.finfo(np.double).eps:
C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py:275: RuntimeWarning: invalid value encountered in less
  if np.all(np.dot(Yk.T, Yk) < np.finfo(np.double).eps):
Traceback (most recent call last):
  File "C:\svn\hw4\code\test_plsr2.py", line 8, in <module>
    pls2.fit(xx, yy)
  File "C:\Anaconda\lib\site-packages\sklearn\cross_decomposition\pls_.py", line 335, in fit
    linalg.pinv(np.dot(self.x_loadings_.T, self.x_weights_)))
  File "C:\Anaconda\lib\site-packages\scipy\linalg\basic.py", line 889, in pinv
    a = _asarray_validated(a, check_finite=check_finite)
  File "C:\Anaconda\lib\site-packages\scipy\_lib\_util.py", line 135, in _asarray_validated
    a = np.asarray_chkfinite(a)
  File "C:\Anaconda\lib\site-packages\numpy\lib\function_base.py", line 613, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

有什么问题吗?

我知道scikit-learn GitHub issue #2089,但是由于我使用scikit learn 0.16.1(与Python 2.7.10 x64一起使用),这个问题应该得到解决(GitHub问题中提到的代码片段可以正常工作)。


Tags: inpylibpackagesnplinesiteanaconda
3条回答

请检查传入的值是否为NaN或inf:

np.isnan(xx).any()
np.isnan(yy).any()

np.isinf(xx).any()
np.isinf(yy).any()

如果这些都是真的。删除nan项或inf项。E、 g.您可以使用以下命令将它们设置为0

xx = np.nan_to_num(xx)
yy = np.nan_to_num(yy)

numpy也有可能被输入如此大的正、负和零值,以至于库中深层的方程正在产生零,Nan或Inf。奇怪的是,一种解决方法是发送较小的数字(比如-1到1之间的代表数字)。一种方法是通过标准化,请参见:https://stackoverflow.com/a/36390482/445131

如果这些都不能解决问题,那么您可能正在处理您使用的库中的一个低级错误,或者数据中的某种奇点。创建一个sscce并将其发布到stackoverflow,或者在维护软件的库上创建一个新的bug报告。

我可以复制相同的bug,我通过过滤掉所有的0来消除这个bug

threshold_for_bug = 0.00000001 # could be any value, ex numpy.min
xx[xx < threshold_for_bug] = threshold_for_bug

这样可以消除错误(我从不检查精度差异)

我的系统信息:

numpy-1.11.2
python-3.5
macOS Sierra

这个问题是由scikit learn中的一个错误引起的。我在GitHub上报道过:https://github.com/scikit-learn/scikit-learn/issues/2089#issuecomment-152753095

相关问题 更多 >

    热门问题