Error in fit: ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Posted 2024-04-19 15:02:51


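I'm trying to tune a random forest classifier with GridSearchCV. I have the code below, but I get the error from the title and I don't understand why. I suspect it's something in y_train, which is a numpy array. Here is my code: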

import numpy as np
from sklearn.model_selection import GridSearchCV

def tun_RF(model, X_train, Y_train):
    # Candidate values for each hyperparameter to search over
    n_estimators = [10, 50, 90, 130, 170]
    min_samples_split = [np.linspace(1, 200, 10, dtype=int)]
    random_grid = {'n_estimators': n_estimators,
                   'min_samples_split': min_samples_split}
    grid_search = GridSearchCV(estimator=model, param_grid=random_grid, cv=3, n_jobs=-1, verbose=2)
    grid_search.fit(X_train, Y_train)
    return grid_search.best_estimator_

This is what I get:

Rimozione feature non utilizzabili
[[130  82  17   9]
 [ 62 339 113  50]
 [ 19 129 175 165]
 [  5  39 148 342]]
              precision    recall  f1-score   support

           0       0.60      0.55      0.57       238
           1       0.58      0.60      0.59       564
           2       0.39      0.36      0.37       488
           3       0.60      0.64      0.62       534

    accuracy                           0.54      1824
   macro avg       0.54      0.54      0.54      1824
weighted avg       0.54      0.54      0.54      1824

Confusion matrix, without normalization
[[130  82  17   9]
 [ 62 339 113  50]
 [ 19 129 175 165]
 [  5  39 148 342]]
Fitting 3 folds for each of 5 candidates, totalling 15 fits
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_split.py:657: Warning: The least populated class in y has only 1 members, which is too few. The minimum number of members in any class cannot be less than n_splits=3.
  % (min_groups, self.n_splits)), Warning)
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=10 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=10 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=90 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=90 
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
[Parallel(n_jobs=-1)]: Done   8 out of  15 | elapsed:    2.1s remaining:    1.8s
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=90 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=130 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=50 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=50 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=50 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=10 
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=130 
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=130 
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=170 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=170 
[CV] min_samples_split=[  1  23  45  67  89 111 133 155 177 200], n_estimators=170 
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:528: FutureWarning: From version 0.22, errors during fit will result in a cross validation score of NaN by default. Use error_score='raise' if you want an exception raised or error_score=np.nan to adopt the behavior from version 0.22.
  FutureWarning)
joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 418, in _process_worker
    r = call_item()
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 272, in __call__
    return self.fn(*self.args, **self.kwargs)
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 567, in __call__
    return self.func(*args, **kwargs)
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 225, in __call__
    for func, args, kwargs in self.items]
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 225, in <listcomp>
    for func, args, kwargs in self.items]
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 514, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/ensemble/forest.py", line 330, in fit
    for i, t in enumerate(trees))
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 934, in __call__
    self.retrieve()
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 833, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 657, in get
    raise self._value
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 567, in __call__
    return self.func(*args, **kwargs)
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 225, in __call__
    for func, args, kwargs in self.items]
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 225, in <listcomp>
    for func, args, kwargs in self.items]
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/ensemble/forest.py", line 118, in _parallel_build_trees
    tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/tree/tree.py", line 816, in fit
    X_idx_sorted=X_idx_sorted)
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/tree/tree.py", line 211, in fit
    if not 0. < self.min_samples_split <= 1.:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main_classification.py", line 111, in <module>
    main()
  File "main_classification.py", line 107, in main
    y_pred=RFClassifier(X_train_mean,X_test_mean,Y_train,Y_test)
  File "main_classification.py", line 32, in RFClassifier
    best_params =tun.tun_RF(classifier , X_train , y_train)
  File "/home/andrea/gruppo3/API/scripts_init/modules_and_main/tuning_classifiers.py", line 11, in tun_RF
    grid_search.fit(X_train , Y_train)
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_search.py", line 687, in fit
    self._run_search(evaluate_candidates)
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_search.py", line 1148, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "/home/andrea/.local/lib/python3.7/site-packages/sklearn/model_selection/_search.py", line 666, in evaluate_candidates
    cv.split(X, y, groups)))
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 934, in __call__
    self.retrieve()
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/parallel.py", line 833, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "/home/andrea/.local/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 521, in wrap_future_result
    return future.result(timeout=timeout)
  File "/usr/lib/python3.7/concurrent/futures/_base.py", line 432, in result
    return self.__get_result()
  File "/usr/lib/python3.7/concurrent/futures/_base.py", line 384, in __get_result
    raise self._exception
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

X_train is a pandas DataFrame with shape (7295, 19). Y_train is a numpy array with shape (7295,).
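The last frame of the worker traceback is the range check `if not 0. < self.min_samples_split <= 1.:` in scikit-learn's tree code. The NumPy-only sketch below reproduces that failure outside scikit-learn; `check_min_samples_split` is a hypothetical helper that copies only that one condition, not the library's full validation. Python evaluates `0. < value <= 1.` as `(0. < value) and (value <= 1.)`; when `value` is an array, the first comparison is an element-wise boolean array, and `and` cannot reduce a multi-element array to a single truth value.

```python
import numpy as np


def check_min_samples_split(value):
    """Hypothetical helper mirroring the float-range check shown in the traceback."""
    # Python expands this into (0. < value) and (value <= 1.).
    # For an ndarray, (0. < value) is a boolean array, and `and` needs one
    # truth value, so NumPy raises ValueError for arrays with >1 element.
    return not 0. < value <= 1.


# Each grid candidate in the log received this whole array as min_samples_split.
candidate = np.linspace(1, 200, 10, dtype=int)
print(candidate)

try:
    check_min_samples_split(candidate)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
    print("ValueError:", error_message)

# A single number is checked without any problem.
print(check_min_samples_split(0.5))  # False: 0.5 lies in (0, 1]
```

Passing one scalar per fit, which is what GridSearchCV does when the grid lists the values themselves, avoids the ambiguous comparison entirely.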


1 answer
User
#1 · posted 2024-04-19 15:02:51

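With @Jeppe's help I solved it. The problem was the extra square brackets around np.linspace: np.linspace already returns a numpy array, and wrapping it in [ ] produced a list holding a single array, so GridSearchCV passed the whole array to the forest as one value of min_samples_split (which is why the log says "5 candidates" instead of 5 × 10). The other problem was that starting the splits at 1 makes no sense (the result would be a tree whose leaves just reproduce y, and scikit-learn only accepts integer values of at least 2 anyway), so I changed it to: min_samples_split = np.arange(start=10, stop=200, step=10, dtype=int). Now it works!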
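A minimal sketch of the fix, assuming a scikit-learn version that provides `sklearn.model_selection.ParameterGrid` (the class `GridSearchCV` uses to expand `param_grid`, visible in the traceback). It counts the candidates produced by the original grid and by the fixed one, then shows `tun_RF` from the question with the answer's grid applied. The names `bad_grid` and `good_grid` are hypothetical and exist only for this comparison.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, ParameterGrid

n_estimators = [10, 50, 90, 130, 170]

# Original grid: the extra brackets make the whole array ONE candidate value.
bad_grid = {'n_estimators': n_estimators,
            'min_samples_split': [np.linspace(1, 200, 10, dtype=int)]}

# Fixed grid from the answer: the array itself lists the candidates
# (10, 20, ..., 190); it starts at 10 because integers below 2 are rejected.
good_grid = {'n_estimators': n_estimators,
             'min_samples_split': np.arange(start=10, stop=200, step=10, dtype=int)}

print(len(ParameterGrid(bad_grid)))   # 5 candidates, matching the log
print(len(ParameterGrid(good_grid)))  # 5 * 19 = 95 candidates


def tun_RF(model, X_train, Y_train):
    """Grid-search n_estimators and min_samples_split; return the refitted best model."""
    random_grid = {'n_estimators': [10, 50, 90, 130, 170],
                   'min_samples_split': np.arange(start=10, stop=200, step=10, dtype=int)}
    grid_search = GridSearchCV(estimator=model, param_grid=random_grid,
                               cv=3, n_jobs=-1, verbose=2)
    grid_search.fit(X_train, Y_train)
    return grid_search.best_estimator_
```

Checking len(ParameterGrid(...)) before launching a long search is a cheap way to confirm that each key expands to the number of values you expect.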
