PySpark:如何高效地读取列位置不同的多个CSV文件
我正在尝试用Spark高效地读取一个文件夹里的多个CSV文件。可惜的是,我还没找到比逐个读取每个文件更好的方法,这样做非常耗时间。
根据我的理解,读取多个CSV文件最有效的方法是使用*
,像这样:
df = spark.read.format('csv') \
.option('header', 'true') \
.load('/path/to/csv/folder/*.csv')
不过,虽然这种方法很快,但它并不会根据列名来合并数据,而是按照列的顺序来处理。比如,如果这个文件夹里有以下两个CSV文件:
1.csv
:
A | B | C |
---|---|---|
1 | 2 | 5 |
3 | 4 | 6 |
2.csv
:
A | C |
---|---|
7 | 8 |
那么之前的方法会把它们合并成这样:
df
:
A | B | C |
---|---|---|
1 | 2 | 5 |
3 | 4 | 6 |
7 | 8 | NULL |
这显然是错误的,因为最后一行应该是7|NULL|8
。
我通过逐个读取每个文件,然后使用unionByName
方法,并把allowMissingColumns
参数设置为True
来解决了这个问题,像这样:
dfs = []
for filename in list_file_names('/path/to/csv/folder'):
dfs.append(spark.read.format('csv') \
.option('header', 'true') \
.load('/path/to/csv/folder/{filename}')
)
union_df = dfs[0]
for df in dfs[1:]:
union_df = union_df.unionByName(df, allowMissingColumns=True)
这样做效果很好,但因为我逐个读取文件,所以速度慢得多。在同一台机器上,处理100个小CSV文件时,第一种(但错误的)方法大约需要6秒,而第二种方法则需要16秒。
所以我的问题是,能不能像第一种方法那样,只进行一次读取操作,就在PySpark中实现相同的结果呢?
2 个回答
谢谢你,@parisni!
根据你的建议,我实现了一个更快的解决方案,具体如下:
首先,我定义了一个函数,用来读取CSV文件的表头,并确定需要进行的最少次数的 unionByName
操作。
def read_headers():
"""
Read headers from the CSV files and return them as a dict, where the key is a tuple of the indices
of the columns of interest, and the value is the list of file paths
Note: To minimize the number of unionByName operations, we only look at the indexes of the columns
of interest in the subsequent queries
"""
headers = {}
for filename in list_file_names(DATASET_PATH):
path = f'{DATASET_PATH}/{filename}'
header = spark.read.option("header", "true").csv(path).columns
# Another way to go is with smart_open as follows:
# header = smart_open.smart_open(path).readline().strip()
key = tuple(header.index(c) for c in COLUMNS_OF_INTEREST)
headers[key] = headers.get(key, []) + [filename]
return headers
然后我运行 unionByName
,具体代码如下:
headers = read_headers()
dfs = []
for filenames in headers.values():
dfs.append(spark.read.format('csv')
.option('header', 'true')
.load(filenames))
union_df = reduce(lambda x, y: x.unionByName(y, allowMissingColumns=True), dfs)
如果有人好奇的话, list_file_names
函数的定义如下:
def list_file_names(directory_path):
"""
List all files in a given directory on the hdfs filesystem
"""
file_status_objects = sc._gateway.jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration()).listStatus(
sc._jvm.org.apache.hadoop.fs.Path(directory_path)
)
return sorted([str(file.getPath().getName()) for file in file_status_objects])
我能否通过一次读取操作在PySpark中实现相同的结果?
很遗憾,由于合并模式的限制,你不能一次性使用Spark的数据源API。
不过,你可以优化你的合并方法,先读取每个文件的表头,然后按照CSV的类别进行分组,最后再将每个类别的文件合并在一起。
获取与文件路径相关的第一行数据,可以用纯Python实现,比如使用boto。
然后,可以通过用逗号分隔的路径列表一次性读取多个文件。
如果你的CSV类别不多,采用两步走的方法会比逐个合并每个文件快得多。