如何在Spark决策树模型中获取节点信息

2024-04-20 10:40:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我想通过Spark MLlib的决策树获得关于生成模型的每个节点的更详细的信息。最接近我使用API的是print(model.toDebugString()),它返回如下内容(取自PySpark文档)

  DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0

如何修改MLlib源代码以获得每个节点的杂质和深度?(如果有必要,如何在PySpark中调用新的Scala函数?)在


Tags: 文档模型api信息决策树内容model节点
2条回答

不幸的是,我找不到任何方法直接访问PySpark或Spark(scalaapi)中的节点。但是有一种方法可以从根节点开始遍历到不同的节点。在

(我在这里刚刚提到了杂质,但是对于深度,人们可以很容易地用impurity代替subtreeDepth。)

假设决策树模型实例是dt

火花塞

root = dt.call("topNode")
root.impurity() # gives the impurity of the root node

现在,如果我们看看适用于root的方法:

^{pr2}$

我们可以:

root.leftNode().get().impurity()

这可能会深入到树的深处,例如:

root.leftNode().get().rightNode().get().impurity()

因为在应用leftNode()rightNode()之后,我们得到一个option,应用get或getOrElseis necessary to get to the desired节点类型。在

如果你想知道我是怎么得到这些奇怪的方法的,我得承认,我作弊了!!,即我首先研究了Scala API:

火花

以下几行与上面的完全等价,假设dt相同,则给出相同的结果:

val root = dt.topNode
root.impurity

我们可以:

root.leftNode.get.impurity

这可能会深入到树的深处,例如:

root.leftNode.get.rightNode.get.impurity

我将通过描述如何使用pyspark2.4.3来补充@mostofmoly的答案。在

根节点

给定一个经过训练的决策树模型,下面是如何获取其根节点的方法:

def _get_root_node(tree: DecisionTreeClassificationModel):
    return tree._call_java('rootNode')

杂质

我们可以从根节点走下树来得到杂质。它的pre-order transversal可以这样做:

^{pr2}$

示例

In [1]: print(tree.toDebugString)
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_f90ba6dbb0fe) of depth 3 with 7 nodes
  If (feature 0 <= 6.5)
   If (feature 0 <= 3.5)
    Predict: 1.0
   Else (feature 0 > 3.5)
    If (feature 0 <= 5.0)
     Predict: 0.0
    Else (feature 0 > 5.0)
     Predict: 1.0
  Else (feature 0 > 6.5)
   Predict: 0.0


In [2]: cat.get_impurities(tree)
Out[2]: [0.4444444444444444, 0.5, 0.5]

相关问题 更多 >