为什么Python中map()和multiprocessing.Pool.map()的结果不同?
我遇到了一个奇怪的问题。我有一个特定格式的文件:
START
1
2
STOP
lllllllll
START
3
5
6
STOP
我想要读取从 START
到 STOP
之间的内容,把这些内容当作一个个块来处理,并用 my_f
来处理每个块。
def block_generator(file):
with open(file) as lines:
for line in lines:
if line == 'START':
block=itertools.takewhile(lambda x:x!='STOP',lines)
yield block
在我的主函数里,我尝试用 map()
来完成这个工作,结果是成功的。
blocks=block_generator(file)
map(my_f,blocks)
这样确实能得到我想要的结果。但是当我尝试用 multiprocessing.Pool.map()
时,出现了一个错误,提示 takewhile()
需要两个参数,但我给了它零个。
blocks=block_generator(file)
p=multiprocessing.Pool(4)
p.map(my_f,blocks)
这是个bug吗?
- 这个文件有超过1000000个块,每个块少于100行。
- 我接受来自untubu的回答。
- 不过我可能会简单地把文件拆分,然后用我原来的脚本的多个实例来处理这些块,而不使用多进程,最后再把结果合在一起。这样只要脚本在小文件上能正常工作,就不会出错。
2 个回答
简单来说,当你像现在这样遍历一个文件时,每次从文件中读取一行,文件指针就会向前移动一行。
所以,当你执行
block=itertools.takewhile(lambda x:x!='STOP',lines)
每次由 takewhile
返回的迭代器从 lines
中获取新项目时,文件指针就会移动。
在 for
循环中,同时推进一个你已经在循环的迭代器通常是不好的做法。不过,每次遇到 yield
时,for
循环会暂时暂停,而 map
会在继续 for
循环之前先把 takewhile
用完,所以你得到了想要的效果。
当你同时运行 for
循环和 takewhile
时,文件指针会快速移动到文件末尾,这样就会出错。
试试这个,应该比把 takewhile
包裹在 list
中要快:
from contextlib import closing
from itertools import repeat
def block_generator(filename):
with open(filename) as infile:
for pos in (infile.tell() for line in infile if line == 'START'):
yield pos
def my_f_wrapper(pos, filename):
with open(filename) as infile:
infile.seek(pos)
block=itertools.takewhile(lambda x:x!='STOP', infile)
my_f(block)
blocks = block_generator(filename)
p.imap(my_f_wrapper, blocks, repeat(filename))
基本上,你希望每个 my_f
独立操作文件,所以你需要为每个 my_f
单独打开文件。
我想不出一种方法可以不让文件被遍历两次,一次是通过 for
循环,另一次是通过所有的 takewhile
,同时又能并行处理文件。在你最初的版本中,takewhile
推进了 for
循环的文件指针,所以效率很高。
如果你不是在遍历行,而是字节的话,我会推荐使用 mmap,但如果你在处理文本行,这样会让事情变得复杂很多。
编辑: 另一种选择是让 block_generator
遍历文件,找到所有 START
和 STOP
的位置,然后成对地传递给包装器。这样,包装器就不需要将行与 STOP
进行比较,只需在文件上使用 tell()
确保它不在 STOP
处。我不确定这样是否会更快。
这样说吧:
import itertools
def grouper(n, iterable, fillvalue=None):
# Source: http://docs.python.org/library/itertools.html#recipes
"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
return itertools.izip_longest(*[iter(iterable)]*n,fillvalue=fillvalue)
def block_generator(file):
with open(file) as lines:
for line in lines:
if line == 'START':
block=list(itertools.takewhile(lambda x:x!='STOP',lines))
yield block
blocks=block_generator(file)
p=multiprocessing.Pool(4)
for chunk in grouper(100,blocks,fillvalue=''):
p.map(my_f,chunk)
使用 grouper
可以限制 p.map
处理文件的数量。这样就不需要一次性把整个文件都读到内存里(放到任务队列中)。
我之前提到,当你调用 p.map(func, iterator)
时,整个迭代器会立刻被消耗掉,以填满任务队列。然后,池中的工作线程会从队列中获取任务,并同时处理这些工作。
如果你查看 pool.py 的内容,跟踪一下定义,你会看到 _handle_tasks
线程会从 self._taskqueue
中获取项目,并一次性列举出来:
for i, task in enumerate(taskseq):
...
put(task)
结论是,传递给 p.map
的迭代器会被一次性消耗掉。并不是等一个任务完成后再从队列中获取下一个任务。
为了进一步证明这一点,如果你运行这个:
演示代码:
import multiprocessing as mp
import time
import logging
def foo(x):
time.sleep(1)
return x*x
def blocks():
for x in range(1000):
if x%100==0:
logger.info('Got here')
yield x
logger=mp.log_to_stderr(logging.DEBUG)
logger.setLevel(logging.DEBUG)
pool=mp.Pool()
print pool.map(foo, blocks())
你会看到 Got here
的消息几乎立刻打印了10次,然后因为 foo
中的 time.sleep(1)
调用而出现了很长的暂停。这明显表明,迭代器在池中的进程完成任务之前就已经完全被消耗掉了。