How can I use join() efficiently to reduce runtime when there are many columns?

I have a script that melts (with a melt() helper) various DataFrames and then joins them. It performs an equi-join.

I would not need any joins at all if I could melt the different column groups at once. As it stands, AWS Glue takes about an hour to run this, and Spark on my local machine fails with heap-size memory errors.
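
For what it's worth, the "melt different column groups at once" idea can be written directly in PySpark: give every period one struct carrying all measures, so a single explode replaces the joins entirely. Below is a minimal sketch under that assumption; melt_multi and all column names are hypothetical, and the column lists per measure must be aligned on the same periods.

from pyspark.sql.functions import array, col, explode, lit, struct


def melt_multi(df, id_vars, periods, groups, var_name="period"):
    """Melt several measure groups in one explode, avoiding joins.

    periods: period labels, e.g. ["Jan16", "Feb16"].
    groups:  output measure name -> its wide columns, aligned with
             periods, e.g. {"value": ["Value1", "Value2"],
                            "sales": ["Sales1", "Sales2"],
                            "vol":   ["Vol1",   "Vol2"]}.
    """
    # One struct per period, carrying the period label plus one field
    # per measure; all structs share the same schema, as array() needs.
    structs = [
        struct(lit(p).alias(var_name),
               *(col(cols[i]).alias(m) for m, cols in groups.items()))
        for i, p in enumerate(periods)
    ]
    melted = df.withColumn("_m", explode(array(*structs)))
    # One long row per (id, period), with every measure side by side.
    return melted.select(*id_vars,
                         *(col("_m")[f].alias(f)
                           for f in [var_name, *groups]))

Applied to the sample dataset shown further down, this yields one row per (market, period) with value, sales and vol side by side, and no joins.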

I have tried investigating both melt() and join() on these DataFrames with different test cases.

My melt() code is taken from another source:

from typing import Iterable

from pyspark.sql import DataFrame
from pyspark.sql.functions import array, col, explode, lit, struct


def melt_df(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format."""

    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = array(*(
        struct(lit(c).alias(var_name), col(c).alias(value_name))
        for c in value_vars))
    # Add to the DataFrame and explode
    _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))

    cols = list(id_vars) + [
        col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)
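
As a usage sketch (column names hypothetical, mirroring the sample dataset below), melting two wide Value columns into one (period, value) pair per row:

long_df = melt_df(df, id_vars=["markets"],
                  value_vars=["Value1", "Value2"],
                  var_name="period", value_name="value")
# markets | period | value
# Lucknow | Value1 | 1
# Lucknow | Value2 | 2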

As for the join:

df = df.join(df_sales, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
df = df.join(df_vol, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
df = df.join(df_gr_val, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
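
Note that df.subtract(df.filter(cond)) is a heavyweight way to write df.filter(~cond): subtract() shuffles and deduplicates both sides. If duplicate rows are not expected, the cleanup can happen once, before any join; a sketch of the equivalent:

from pyspark.sql.functions import col

# Drop the stray "Markets" header row and blank markets once, up
# front, instead of re-filtering after every join.
df = df.filter((col("markets") != "Markets") & (col("markets") != ""))
df = (df.join(df_sales, ["markets", "period"])
        .join(df_vol, ["markets", "period"])
        .join(df_gr_val, ["markets", "period"]))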

Sample dataset:

unnamed_0, unnamed_1, Value1, Value2, Sales1, Sales2, Vol1,  Vol2
Markets,   Channel,   Jan16,  Feb16,  Jan16,  Feb16,  Jan16, Feb16
Lucknow,   no9,       1,      2,      3,      4,      5,     6
Delhi,     no10,      2,      3,      4,      5,      6,     7
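
After melting and joining, every (market, period) pair should collapse into one long-format row, i.e. (values carried over from the sample above):

markets, period, value, sales, vol
Lucknow, Jan16,  1,     3,     5
Lucknow, Feb16,  2,     4,     6
Delhi,   Jan16,  2,     4,     6
Delhi,   Feb16,  3,     5,     7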

Full code:

# gr_val_ya
# gr_vol_ya

# Below code is required at the start of each script
import sys

from pyspark.sql.functions import (array, col, explode, lit,
                                   monotonically_increasing_id, struct)
from pyspark.sql.types import (FloatType, IntegerType, StringType,
                               StructField, StructType)


# spark context is already created for you as athena_context.sc, spark session is already
# created for you as athena_context.spark
# include this line in your code, though this eg doesn’t need it : spark = athena_context.spark

# get two input names
# 'S3' & 'S3_1' are the name of the input node that you attached to your processing node
# df = athena_context.input('S3')
# df2 = athena_context.input('S3_1')

# write rest of the code here
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('target_melt').getOrCreate()
last_correct_constant = 11


def sanitize_columns_in_df(df):
    # Normalize column names: replace brackets, spaces, slashes and
    # dots with underscores, then lower-case everything.
    for column in df.columns:
        new_name = (column.replace("(", "_").replace(")", "_")
                    .replace(" ", "_").replace("/", "_")
                    .replace(".", "_").lower())
        df = df.withColumnRenamed(column, new_name)
    return df


def get_period_df(df):
    # The row whose first column is the literal "Markets" carries the
    # period labels (Jan16, Feb16, ...).
    df = df.filter(df["unnamed__0"] == 'Markets')
    return df


def getrows(df, rownums=None):
    # Pick rows by position; zipWithIndex pairs each row with its index.
    return df.rdd.zipWithIndex().filter(
        lambda x: x[1] in rownums).map(lambda x: x[0])


def melt_df(df, id_vars, value_vars, var_name="variable", value_name="value"):
    """Convert :class:`DataFrame` from wide to long format."""

    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = array(*(
        struct(lit(c).alias(var_name), col(c).alias(value_name))
        for c in value_vars))
    # Add to the DataFrame and explode
    _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))
    cols = list(id_vars) + [
        col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)


def remove_unwanted_cols(df, col_name):
    for _col in df.columns:
        if _col.startswith("unnamed__"):
            # Drop the trailing unnamed columns past the id block.
            if int(_col.split("unnamed__")[-1]) > last_correct_constant:
                df = df.drop(_col)
        else:
            # Drop every measure column that does not belong to col_name.
            if not _col.startswith(col_name):
                df = df.drop(_col)
    return df


def make_first_row_head(df):
    # Promote the "Markets" period row to the column headers:
    # isolate that row, rebuild it as a one-row DataFrame with a
    # matching schema, subtract it from the data, then rename the
    # remaining columns after its values (Jan16, Feb16, ...).
    period_df = get_period_df(df)
    period_row = getrows(period_df, rownums=[0]).collect()[0]
    print("Period row from df is: ")
    print(period_row)
    period_row_dict = period_row.asDict()
    data_list = []
    schema = []
    for k, v in period_row_dict.items():
        data_list.append(v)
        # Mirror each value's Python type in the schema.
        if isinstance(v, int):
            schema.append(StructField(k, IntegerType()))
        elif isinstance(v, float):
            schema.append(StructField(k, FloatType()))
        else:
            schema.append(StructField(k, StringType()))
    schema = StructType(schema)
    period_df = spark.createDataFrame([data_list], schema)
    rest_df = df.subtract(period_df)
    header_column = period_df.first()
    for column in rest_df.columns:
        rest_df = rest_df.withColumnRenamed(column, header_column[column])
    return rest_df


def remove_cols_for_join(df, col_name):
    # Keep only the join keys (Markets, period) and col_name's columns.
    for _col in df.columns:
        if _col != 'period' and (not _col.startswith(col_name)) and _col != 'Markets':
            df = df.drop(_col)
    return df


# for Value column
df = spark.read.orc('rename_columns')
print("Fresh cols are: ")
print(df.columns)
df = remove_unwanted_cols(df, 'value_offtake_000_rs__')
print("DF cols after dropping unwanted are: ")
print(df.columns)
df = make_first_row_head(df)
print("DF columns after making first row as header:")
print(df.columns)
# column headers
table_columns = df.columns
df = melt_df(df, table_columns[:last_correct_constant+1],
             table_columns[last_correct_constant+1:], 'period', 'value_offtake_000_rs___')

# for sales column
df_sales = spark.read.orc('rename_columns')
print("Fresh cols are: ")
print(df_sales.columns)
df_sales = remove_unwanted_cols(df_sales, 'sales_volume__volume_litres__')
print("DF Sales cols after dropping unwanted are: ")
print(df_sales.columns)
df_sales = make_first_row_head(df_sales)
print("DF Sales columns after making first row as header:")
print(df_sales.columns)

table_columns = df_sales.columns
print("Table cols df_sales are: ")
print(table_columns)
df_sales = melt_df(df_sales, table_columns[:last_correct_constant+1],
                   table_columns[last_correct_constant+1:], 'period', 'sales_volume__volume_litres___')

# remove all cols except period and Sales Volume
print(df_sales.columns)
df_sales = remove_cols_for_join(df_sales, 'sales_volume__volume_litres___')
print("After removing cols for join, df_sales' cols are: ")
print(df_sales.columns)

# for gr vol
df_vol = spark.read.orc('rename_columns')
print("Fresh cols are: ")
print(df_vol.columns)
df_vol = remove_unwanted_cols(df_vol, 'gr_vol_ya')
print("DF Vol cols after dropping unwanted are: ")
print(df_vol.columns)
df_vol = make_first_row_head(df_vol)
print("DF Sales columns after making first row as header:")
print(df_vol.columns)

table_columns = df_vol.columns
print("Table cols df_vol are: ")
print(table_columns)
df_vol = melt_df(df_vol, table_columns[:last_correct_constant+1],
                 table_columns[last_correct_constant+1:], 'period', 'gr_vol_ya')

# remove all cols except period and gr_vol_ya
print(df_vol.columns)
df_vol = remove_cols_for_join(df_vol, 'gr_vol_ya')
print("After removing cols for join, df_vol's cols are: ")
print(df_vol.columns)

# for gr_val_ya
df_gr_val = spark.read.orc('rename_columns')
print("Fresh cols are: ")
print(df_gr_val.columns)
df_gr_val = remove_unwanted_cols(df_gr_val, 'gr_val_ya')
print("DF GrVal cols after dropping unwanted are: ")
print(df_gr_val.columns)
df_gr_val = make_first_row_head(df_gr_val)
print("DF Sales columns after making first row as header:")
print(df_gr_val.columns)

table_columns = df_gr_val.columns
print("Table cols df_vol are: ")
print(table_columns)
df_gr_val = melt_df(df_gr_val, table_columns[:last_correct_constant+1],
                    table_columns[last_correct_constant+1:], 'period', 'gr_val_ya')

# remove all cols except period and gr_val_ya
print(df_gr_val.columns)
df_gr_val = remove_cols_for_join(df_gr_val, 'gr_val_ya')
print("After removing cols for join, df_gr_val's cols are: ")
print(df_gr_val.columns)

# sanitize df's columns
df = sanitize_columns_in_df(df)
df_sales = sanitize_columns_in_df(df_sales)
df_vol = sanitize_columns_in_df(df_vol)
df_gr_val = sanitize_columns_in_df(df_gr_val)
print("After sanitizing cols of all 4 dataframes:")
print("df")
print(df.columns)
print("df_sales")
print(df_sales.columns)
print("df vol")
print(df_vol.columns)
print("df gr val")
print(df_gr_val.columns)
df = df.join(df_sales, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
df = df.join(df_vol, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
df = df.join(df_gr_val, ["markets", "period"])
df = df.subtract(df.filter(df["markets"] == "Markets"))
df = df.subtract(df.filter(df["markets"] == ''))
df.write.orc('output_rename_columns_6')
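
Since df_sales, df_vol and df_gr_val are cut down to three columns each before the joins, they may be small enough to broadcast; hinting Spark to do so avoids shuffling the large frame three times. A sketch, valid only if they really do fit in executor memory:

from pyspark.sql.functions import broadcast

df = (df.join(broadcast(df_sales), ["markets", "period"])
        .join(broadcast(df_vol), ["markets", "period"])
        .join(broadcast(df_gr_val), ["markets", "period"]))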

