How can I find the hierarchy of people (employee, manager, etc.) using GraphFrames in PySpark?

Posted 2024-04-26 07:36:43


I have a GraphFrame with the vertices and edges shown below. I am running this in PySpark from a Jupyter notebook.

vertices = sqlContext.createDataFrame([
    ("12345", "Alice", "Employee"),
    ("15789", "Bob", "Employee"),
    ("13467", "Charlie", "Manager"),
    ("14890", "David", "Director"),
    ("17737", "Fanny", "CEO"),
], ["id", "name", "title"])

edges = sqlContext.createDataFrame([
    ("12345", "13467", "works"),
    ("15789", "13467", "works"),
    ("13467", "14890", "works"),
    ("14890", "17737", "works"),
], ["src", "dst", "relationship"])

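I need to find the hierarchical path from each emp_id up to the top level (the CEO in this example). I am trying the bfs approach, and so far I have only managed to get the path for a single emp_id. Here is my code: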

g = GraphFrame(vertices,edges)
result = g.bfs(fromExpr = "id == '12345'", toExpr = "title == 'CEO'", edgeFilter = "relationship == 'works'", maxPathLength = 5)
result.show(5,False)

Output:

+----------------------+-------------------+-----------------------+-------------------+----------------------+-------------------+-----------------+
|from                  |e0                 |v1                     |e1                 |v2                    |e2                 |to               |
+----------------------+-------------------+-----------------------+-------------------+----------------------+-------------------+-----------------+
|[12345,Alice,Employee]|[12345,13467,works]|[13467,Charlie,Manager]|[13467,14890,works]|[14890,David,Director]|[14890,17737,works]|[17737,Fanny,CEO]|
+----------------------+-------------------+-----------------------+-------------------+----------------------+-------------------+-----------------+

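I can store this result in a variable and extract it with the collect() method. What I want is to loop over the ids of all vertices that have a path to the CEO and write the paths to a DataFrame. If anyone is familiar with GraphFrames, could you help me? I have looked at other solutions, but none of them worked in my case.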

Expected output:

+-------+--------------------------+
|user_id|path                      |
+-------+--------------------------+
|12345  |12345->13467->14890->17737|
|15789  |15789->13467->14890->17737|
|13467  |13467->14890->17737       |
|14890  |14890->17737              |
|17737  |17737                     |
+-------+--------------------------+

1 answer
User
#1 · posted 2024-04-26 07:36:43

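This adapts this answer to your question and tidies up its result to get the output you want. Note that you need to swap 'src' and 'dst' in the edges DataFrame for that answer to work as written, although I think it would also be possible to modify that answer so that the edges DataFrame can be used in its original form.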
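Before the GraphFrames code, the idea can be sketched without Spark. The snippet below is only an illustration in plain Python, not the GraphFrames API: the names `state`, `inbox`, and `root` are made up for this sketch. It starts at the CEO with distance 0, and in each round every vertex offers its path to its neighbours along manager-to-report edges, which is why the edge direction is reversed relative to the question.

```python
import math

# Edges as (manager, report): the reverse of the question's (employee, manager)
# pairs, so that messages travel from the CEO down to each employee.
edges = [
    ("13467", "12345"),
    ("13467", "15789"),
    ("14890", "13467"),
    ("17737", "14890"),
]
vertex_ids = ["12345", "15789", "13467", "14890", "17737"]
root = "17737"  # the CEO's id (hypothetical variable name for this sketch)

# Per-vertex state: (distance from root, path from root to this vertex).
state = {v: (math.inf, []) for v in vertex_ids}
state[root] = (0, [root])

for _ in range(len(vertex_ids)):  # upper bound on the number of rounds
    # Send phase: a vertex offers its path to a neighbour only if that
    # would shorten the neighbour's current distance.
    inbox = {}
    for src, dst in edges:
        src_dist, src_path = state[src]
        if src_dist + 1 < state[dst][0]:
            inbox.setdefault(dst, []).append((src_dist + 1, src_path + [dst]))
    if not inbox:
        break  # no vertex improved, so the paths are final
    # Aggregate phase: each vertex keeps the shortest offer it received.
    for dst, offers in inbox.items():
        state[dst] = min(offers, key=lambda offer: offer[0])

# Paths were built from the CEO downward, so reverse them for display.
paths = {v: "->".join(reversed(path)) for v, (dist, path) in state.items()}
for v in vertex_ids:
    print(v, paths[v])
```

The send and aggregate phases here play the roles of sendMsgToDst and aggMsgs in the code that follows, and the update of `state` plays the role of the vertex program.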

from graphframes import GraphFrame
from graphframes.lib import Pregel
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructType

vertices = spark.createDataFrame([
      ("12345", "Alice", "Employee"),
      ("15789", "Bob", "Employee"),
      ("13467", "Charlie", "Manager"),
      ("14890", "David", "Director"),
      ("17737", "Fanny", "CEO")], ["id", "name", "title"])

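# The column names are deliberately swapped relative to the question
# (first column is "dst", second is "src"), so each edge points from a
# manager to a direct report and messages can flow down from the CEO.
edges = spark.createDataFrame([
    ("12345", "13467", "works"),
    ("15789", "13467", "works"),
    ("13467", "14890", "works"),
    ("14890", "17737", "works"),
], ["dst", "src", "relationship"])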

g = GraphFrame(vertices,edges)

vertColSchema = StructType()\
      .add("dist", DoubleType())\
      .add("node", StringType())\
      .add("path", ArrayType(StringType(), True))

def vertexProgram(vd, msg):
    # vd and msg are (dist, node, path) structs; keep whichever has the smaller dist.
    if msg is None or vd[0] < msg[0]:
        return (vd[0], vd[1], vd[2])
    else:
        return (msg[0], vd[1], msg[2])

vertexProgramUdf = F.udf(vertexProgram, vertColSchema)

def sendMsgToDst(src, dst):
    # Offer dst a path through src only if that would shorten dst's distance.
    srcDist = src[0]
    dstDist = dst[0]
    if srcDist < (dstDist - 1):
        return (srcDist + 1, src[1], src[2] + [dst[1]])
    else:
        return None

sendMsgToDstUdf = F.udf(sendMsgToDst, vertColSchema)

def aggMsgs(agg):
    # Keep the message with the smallest dist (field 0 of the struct).
    shortest_dist = min(agg, key=lambda tup: tup[0])
    return (shortest_dist[0], shortest_dist[1], shortest_dist[2])

aggMsgsUdf = F.udf(aggMsgs, vertColSchema)

result = (
    g.pregel.withVertexColumn(
        colName = "vertCol",

        initialExpr = F.when(
            F.col("id") == "17737",  # id is a string column; this is the CEO's id
            F.struct(F.lit(0.0), F.col("id"), F.array(F.col("id")))
        ).otherwise(
            F.struct(F.lit(float("inf")), F.col("id"), F.array(F.lit("")))
        ).cast(vertColSchema),

        updateAfterAggMsgsExpr = vertexProgramUdf(F.col("vertCol"), Pregel.msg())
    )
    .sendMsgToDst(sendMsgToDstUdf(F.col("src.vertCol"), Pregel.dst("vertCol")))
    .aggMsgs(aggMsgsUdf(F.collect_list(Pregel.msg())))
    .setMaxIter(5)    ## This should be greater than the max depth of the graph
    .setCheckpointInterval(1)
    .run()
)

df = result.select("vertCol.node", "vertCol.path").repartition(1)
df.show(truncate=False)
+-----+----------------------------+
|node |path                        |
+-----+----------------------------+
|12345|[17737, 14890, 13467, 12345]|
|15789|[17737, 14890, 13467, 15789]|
|13467|[17737, 14890, 13467]       |
|14890|[17737, 14890]              |
|17737|[17737]                     |
+-----+----------------------------+

final = df.select('node', F.concat_ws('->', F.reverse('path')).alias('path'))
final.show(truncate=False)
+-----+--------------------------+
|node |path                      |
+-----+--------------------------+
|12345|12345->13467->14890->17737|
|15789|15789->13467->14890->17737|
|13467|13467->14890->17737       |
|14890|14890->17737              |
|17737|17737                     |
+-----+--------------------------+
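
On data this small, the result can be cross-checked on the driver with plain Python. This is a rough sketch rather than a replacement for the distributed job: it assumes every employee has exactly one manager and that the edge list fits in memory, and `manager_of` and `path_to_top` are names invented for the sketch. It also computes the depth of the hierarchy, which is the number that setMaxIter has to exceed.

```python
# Same edges as in the question: (employee, manager).
edges = [
    ("12345", "13467"),
    ("15789", "13467"),
    ("13467", "14890"),
    ("14890", "17737"),
]
ids = ["12345", "15789", "13467", "14890", "17737"]

# Assumption: one manager per employee, so a dict lookup is enough.
manager_of = dict(edges)


def path_to_top(emp_id, manager_of):
    """Follow manager links upward until reaching someone with no manager."""
    path = [emp_id]
    seen = {emp_id}
    while path[-1] in manager_of:
        nxt = manager_of[path[-1]]
        if nxt in seen:
            # Guard against bad data such as A reports to B reports to A.
            raise ValueError(f"cycle in reporting chain at {nxt}")
        seen.add(nxt)
        path.append(nxt)
    return path


expected = {i: "->".join(path_to_top(i, manager_of)) for i in ids}

# Longest chain, counted in edges, from any employee to the top.
max_depth = max(len(path_to_top(i, manager_of)) - 1 for i in ids)

for i in ids:
    print(i, expected[i])
print("max depth:", max_depth)
```

Here the longest chain has 3 edges, so the setMaxIter(5) used above leaves some headroom.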
