pyspark中的mapPartitions函数是如何工作的?
我正在尝试用Python学习Spark(也就是Pyspark)。我想了解一下mapPartitions
这个函数是怎么工作的。它接受什么输入,输出又是什么。我在网上找不到合适的例子。假设我有一个包含列表的RDD对象,如下所示。
[ [1, 2, 3], [3, 2, 4], [5, 2, 7] ]
我想从所有列表中删除元素2,我该如何使用mapPartitions
来实现这个目标呢?
4 个回答
-1
def func(l):
for i in l:
yield i+"ajbf"
mylist=['madhu','sdgs','sjhf','mad']
rdd=sc.parallelize(mylist)
t=rdd.mapPartitions(func)
for i in t.collect():
print(i)
for i in t.collect():
print(i)
在上面的代码中,我能够从第二个 for..in 循环中获取数据。根据生成器的说法,一旦它遍历完循环,就不应该再显示值了。
1
需要一个最终的迭代
def filter_out_2(partition):
for element in partition:
sec_iterator = []
for i in element:
if i!= 2:
sec_iterator.append(i)
yield sec_iterator
filtered_lists = data.mapPartitions(filter_out_2)
for i in filtered_lists.collect(): print(i)
30
使用生成器函数和 yield
语法来配合 mapPartitions 会更简单:
def filter_out_2(partition):
for element in partition:
if element != 2:
yield element
filtered_lists = data.mapPartitions(filter_out_2)
41
mapPartition
可以理解为对数据分区进行的映射操作,而不是对分区内的每个元素进行操作。它的输入是当前的分区集合,输出则是另一个分区集合。
你传给 map
操作的函数必须处理你 RDD 中的单个元素。
而你传给 mapPartition
的函数则需要处理一个可迭代的 RDD 类型集合,并返回一个可迭代的其他类型或相同类型的集合。
在你的情况下,你可能只想做类似这样的操作:
def filter_out_2(line):
return [x for x in line if x != 2]
filtered_lists = data.map(filterOut2)
如果你想使用 mapPartition
,那么可以这样写:
def filter_out_2_from_partition(list_of_lists):
final_iterator = []
for sub_list in list_of_lists:
final_iterator.append( [x for x in sub_list if x != 2])
return iter(final_iterator)
filtered_lists = data.mapPartition(filterOut2FromPartion)