PySpark: how are DataFrame describe() and summary() implemented?

Posted on 2024-04-26 07:39:52


I would like to know how df.describe() and df.summary() are implemented.

https://spark.apache.org/docs/latest/api/python/_modules/pyspark/sql/dataframe.html#DataFrame.summary

def summary(self, *statistics):
    if len(statistics) == 1 and isinstance(statistics[0], list):
        statistics = statistics[0]
    jdf = self._jdf.summary(self._jseq(statistics))
    return DataFrame(jdf, self.sql_ctx)

I'm not very familiar with OO in Python, so I'm a bit confused. Where are the quantiles and the other statistics actually implemented?
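
For context, this is how I'm calling the two methods (made-up data, just for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, "a"), (2.0, "b"), (5.0, "c")], ["num", "letter"])

# describe() reports count, mean, stddev, min and max
df.describe().show()

# summary() adds approximate percentiles by default and also
# accepts an explicit list of statistics, including "N%" entries
df.summary("count", "min", "25%", "75%", "max").show()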


Tags: https, org, self, api, docs, dataframe, df, sql
1 answer
User
#1 · Posted on 2024-04-26 07:39:52
  • jdf is a reference to the Java Dataset object, accessed through Py4J (see the Py4J note at the end of this answer)
  • The Python code calls its summary method:

    jdf = self._jdf.summary(self._jseq(statistics))
    
  • That Scala Dataset.summary method in turn calls StatFunctions.summary:

    def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
    
  • which is implemented like this (the PySpark sketch after the code shows the percentile part):

    def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
    
    
      val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
      val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
    
    
      val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
        try {
          p.stripSuffix("%").toDouble / 100.0
        } catch {
          case e: NumberFormatException =>
            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
        }
      }
      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
    
    
      var percentileIndex = 0
      val statisticFns = selectedStatistics.map { stats =>
        if (stats.endsWith("%")) {
          val index = percentileIndex
          percentileIndex += 1
          (child: Expression) =>
            GetArrayItem(
              new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(),
              Literal(index))
        } else {
          stats.toLowerCase(Locale.ROOT) match {
            case "count" => (child: Expression) => Count(child).toAggregateExpression()
            case "mean" => (child: Expression) => Average(child).toAggregateExpression()
            case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
            case "min" => (child: Expression) => Min(child).toAggregateExpression()
            case "max" => (child: Expression) => Max(child).toAggregateExpression()
            case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic")
          }
        }
      }
    
    
      val selectedCols = ds.logicalPlan.output
        .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
    
    
      val aggExprs = statisticFns.flatMap { func =>
        selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
      }
    
    
      // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
      lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
    
    
      // We will have one row for each selected statistic in the result.
      val result = Array.fill[InternalRow](selectedStatistics.length) {
        // each row has the statistic name, and statistic values of each selected column.
        new GenericInternalRow(selectedCols.length + 1)
      }
    
    
      var rowIndex = 0
      while (rowIndex < result.length) {
        val statsName = selectedStatistics(rowIndex)
        result(rowIndex).update(0, UTF8String.fromString(statsName))
        for (colIndex <- selectedCols.indices) {
          val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
          result(rowIndex).update(colIndex + 1, statsValue)
        }
        rowIndex += 1
      }
    
    
      // All columns are string type
      val output = AttributeReference("summary", StringType)() +:
        selectedCols.map(c => AttributeReference(c.name, StringType)())
    
    
      Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
    }
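
  • To answer the quantile part directly: the 25% / 50% / 75% rows come from Spark's ApproximatePercentile aggregate, not from an exact sort. You can reach the same machinery from PySpark yourself; a minimal sketch, assuming Spark 3.1+ (for percentile_approx) and a DataFrame df with a numeric column named num (both hypothetical here):

    from pyspark.sql import functions as F

    # the same approximate-percentile aggregate that summary() builds internally
    df.select(F.percentile_approx("num", [0.25, 0.5, 0.75])).show()

    # or via DataFrameStatFunctions; relativeError=0.0 makes the result exact
    print(df.approxQuantile("num", [0.25, 0.5, 0.75], 0.0))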
    

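  • And a side note on the Py4J bridge from the first point: _jdf really is just a Py4J proxy for the JVM Dataset, which you can poke at directly. A small sketch, assuming you already have some DataFrame df:

    # _jdf is a py4j JavaObject wrapping org.apache.spark.sql.Dataset
    print(type(df._jdf))                 # <class 'py4j.java_gateway.JavaObject'>
    print(df._jdf.getClass().getName())  # org.apache.spark.sql.Dataset

    # every method call on it crosses the Py4J bridge into the JVM, which is
    # exactly what DataFrame.summary() does with self._jdf.summary(...)
    jdf = df._jdf.summary(df._jseq(["count", "max"]))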