我想通过Spark MLlib的决策树获得关于生成模型的每个节点的更详细的信息。最接近我使用API的是print(model.toDebugString())
,它返回如下内容(取自PySpark文档)
DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.0)
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
如何修改MLlib源代码以获得每个节点的杂质和深度?(如果有必要,如何在PySpark中调用新的Scala函数?)在
不幸的是,我找不到任何方法直接访问PySpark或Spark(scalaapi)中的节点。但是有一种方法可以从根节点开始遍历到不同的节点。在
(我在这里刚刚提到了杂质,但是对于深度,人们可以很容易地用
impurity
代替subtreeDepth
。)假设决策树模型实例是
dt
:火花塞
现在,如果我们看看适用于
^{pr2}$root
的方法:我们可以:
这可能会深入到树的深处,例如:
因为在应用
leftNode()
或rightNode()
之后,我们得到一个option
,应用get
或getOrElseis necessary to get to the desired
节点类型。在如果你想知道我是怎么得到这些奇怪的方法的,我得承认,我作弊了!!,即我首先研究了Scala API:
火花
以下几行与上面的完全等价,假设
dt
相同,则给出相同的结果:我们可以:
这可能会深入到树的深处,例如:
我将通过描述如何使用pyspark2.4.3来补充@mostofmoly的答案。在
根节点
给定一个经过训练的决策树模型,下面是如何获取其根节点的方法:
杂质
我们可以从根节点走下树来得到杂质。它的pre-order transversal可以这样做:
^{pr2}$示例
相关问题 更多 >
编程相关推荐