H2O Python: the one_hot_explicit parameter

Published 2024-04-26 13:02:55


While training a model with the Python h2o library on H2O v3.10, I get an error when I set the categorical_encoding parameter to one_hot_explicit.

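from h2o.estimators.gbm import H2OGradientBoostingEstimator

encoding = "enum"

gbm = H2OGradientBoostingEstimator(
        categorical_encoding = encoding)

# pass the frames by keyword so each one reaches the intended parameter
gbm.train(x=x, y=y, training_frame=train_h2o_df, validation_frame=test_h2o_df)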

This works fine and the model uses enum categorical encoding, but when I use:

encoding = "one_hot_explicit"

or

encoding = "OneHotExplicit"

I get the following error:

gbm Model Build progress: | (failed)
....
OSError: Job with key $03017f00000132d4ffffffff$_bde8fcb4777df7e0be1199bf590a47f9 failed with an exception: java.lang.AssertionError
stacktrace: 
java.lang.AssertionError
at hex.ModelBuilder.init(ModelBuilder.java:958)
at hex.tree.SharedTree.init(SharedTree.java:78)
at hex.tree.gbm.GBM.init(GBM.java:57)
at hex.tree.SharedTree$Driver.computeImpl(SharedTree.java:159)
at hex.ModelBuilder$Driver.compute2(ModelBuilder.java:169)
at water.H2O$H2OCountedCompleter.compute(H2O.java:1203)
at jsr166y.CountedCompleter.exec(CountedCompleter.java:468)
at jsr166y.ForkJoinTask.doExec(ForkJoinTask.java:263)
at jsr166y.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:974)
at jsr166y.ForkJoinPool.runWorker(ForkJoinPool.java:1477)
at jsr166y.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:104)

Am I missing a dependency, or is this a bug?
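Both spellings tried above name the same option: recent H2O documentation lists each categorical_encoding value in a snake_case and a CamelCase form (for example one_hot_explicit or OneHotExplicit), so the spelling itself is unlikely to be what triggers the AssertionError. As a quick sanity check, the sketch below normalizes either form and compares it with the option list from recent docs. The helper normalize_encoding and the constant DOCUMENTED_ENCODINGS are hypothetical (they are not part of the h2o package), and the list is an assumption: older releases such as 3.10 may accept only a subset of these values.

```python
import re

# Assumption: values listed for categorical_encoding in recent H2O docs;
# an older release such as 3.10 may support only some of them.
DOCUMENTED_ENCODINGS = {
    "auto", "enum", "one_hot_internal", "one_hot_explicit", "binary",
    "eigen", "label_encoder", "sort_by_response", "enum_limited",
}


def normalize_encoding(name):
    """Hypothetical helper: map a CamelCase name such as 'OneHotExplicit'
    to its snake_case form and reject names not in DOCUMENTED_ENCODINGS."""
    # Insert "_" wherever a lowercase letter is followed by an uppercase one,
    # then lowercase the whole string ("OneHotExplicit" -> "one_hot_explicit").
    snake = re.sub(r"(?<=[a-z])(?=[A-Z])", "_", name).lower()
    if snake not in DOCUMENTED_ENCODINGS:
        raise ValueError("unknown categorical_encoding: %r" % name)
    return snake


print(normalize_encoding("OneHotExplicit"))    # one_hot_explicit
print(normalize_encoding("one_hot_explicit"))  # one_hot_explicit
print(normalize_encoding("enum"))              # enum
```

If both spellings normalize to the same documented value, the next thing to compare is the H2O version and the surrounding training code.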


1 answer
Anonymous user
#1 · Posted 2024-04-26 13:02:55

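Your encoding choice should work, although you may want to update to the latest stable release of H2O. If possible, try to find the difference between your original code and the example below.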

import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator
h2o.init()

# import the airlines dataset:
# This dataset is used to classify whether a flight will be delayed ("YES") or not ("NO")
# original data can be found at http://www.transtats.bts.gov/
airlines = h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/airlines/allyears2k_headers.zip")

# convert columns to factors
airlines["Year"] = airlines["Year"].asfactor()
airlines["Month"] = airlines["Month"].asfactor()
airlines["DayOfWeek"] = airlines["DayOfWeek"].asfactor()

# set the predictor names and the response column name
predictors = ["Origin", "Dest", "Year", "DayOfWeek", "Month", "Distance"]
response = "IsDepDelayed"

# split into train and validation sets
train, valid = airlines.split_frame(ratios=[.8], seed=1234)

# try using the `categorical_encoding` parameter:
encoding = "one_hot_explicit"

# initialize the estimator
airlines_gbm = H2OGradientBoostingEstimator(categorical_encoding=encoding, seed=1234)

# then train the model
airlines_gbm.train(x = predictors, y = response, training_frame = train, validation_frame = valid)

# print the AUC for the validation set
print(airlines_gbm.auc(valid=True))
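H2O's documentation describes enum as keeping one column per categorical feature, and one_hot_explicit as expanding a feature with N levels into N+1 new columns. The pure-Python sketch below is not H2O code; it only illustrates that difference on a toy column of airport codes such as Origin. The function names are hypothetical, and treating the extra column as a missing-value indicator (and naming it "missing(NA)") is an assumption made for illustration.

```python
def enum_encode(values):
    """Hypothetical sketch of 'enum'-style encoding: a single integer
    column with one code per distinct level."""
    levels = sorted({v for v in values if v is not None})
    index = {level: i for i, level in enumerate(levels)}
    # Missing entries stay missing (None) in this sketch.
    codes = [index[v] if v is not None else None for v in values]
    return codes, levels


def one_hot_explicit_encode(values):
    """Hypothetical sketch of 'one_hot_explicit'-style encoding: N indicator
    columns for N levels, plus one extra column (assumed here to flag
    missing values)."""
    levels = sorted({v for v in values if v is not None})
    columns = levels + ["missing(NA)"]  # illustrative column name
    rows = []
    for v in values:
        key = "missing(NA)" if v is None else v
        # Exactly one 1 per row, in the column matching the value.
        rows.append([1 if key == c else 0 for c in columns])
    return rows, columns


origin = ["SFO", "JFK", None, "SFO"]

codes, levels = enum_encode(origin)
print(codes, levels)  # [1, 0, None, 1] ['JFK', 'SFO']

rows, columns = one_hot_explicit_encode(origin)
print(columns)  # ['JFK', 'SFO', 'missing(NA)']
print(rows)     # [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

With one_hot_explicit, a column with many levels (such as Dest) expands into many columns, whereas enum keeps it as a single column.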
