numpy与多进程队列组合打乱了队列顺序
我正在使用以下模式来进行多进程处理:
for item in data:
inQ.put(item)
for i in xrange(nProcesses):
inQ.put('STOP')
multiprocessing.Process(target=worker, args=(inQ, outQ)).start()
inQ.join()
outQ.put('STOP')
for result in iter(outQ.get, 'STOP'):
# save result
这个方法运行得很好。但是,如果我通过 outQ
发送一个 numpy 数组,'STOP'
就不会出现在 outQ
的末尾,这导致我获取结果的循环提前结束。
这里有一些代码可以复现这个问题。
import multiprocessing
import numpy as np
def worker(inQ, outQ):
for i in iter(inQ.get, 'STOP'):
result = np.random.rand(1,100)
outQ.put(result)
inQ.task_done()
inQ.task_done() # for the 'STOP'
def main():
nProcesses = 8
data = range(1000)
inQ = multiprocessing.JoinableQueue()
outQ = multiprocessing.Queue()
for item in data:
inQ.put(item)
for i in xrange(nProcesses):
inQ.put('STOP')
multiprocessing.Process(target=worker, args=(inQ, outQ)).start()
inQ.join()
print outQ.qsize()
outQ.put('STOP')
cnt = 0
for result in iter(outQ.get, 'STOP'):
cnt += 1
print "got %d items" % cnt
print outQ.qsize()
if __name__ == '__main__':
main()
如果你把 result = np.random.rand(1,100)
替换成类似 result = i*i
的东西,代码就会按预期工作。
这到底发生了什么?我是不是做错了什么根本性的事情?我本来以为在 inQ.join()
之后调用 outQ.put()
就能达到我想要的效果,因为 join()
会阻塞,直到所有进程都完成了所有的 put()
操作。
一个对我有效的解决方法是用 while outQ.qsize() > 0
来进行结果获取循环,这样也能正常工作。但是我听说 qsize()
不是很可靠。它在不同进程运行时就不可靠吗?在完成 inQ.join()
后依赖 qsize()
是否安全呢?
我预计会有人建议使用 multiprocessing.Pool.map()
,但我在用 numpy 数组(ndarrays)时遇到了 pickle 错误。
谢谢你们的关注!
2 个回答
numpy数组可以进行复杂的比较。所以,当你写 a=='STOP' 时,返回的是一个numpy数组,而不是一个布尔值(真或假)。这个numpy数组不能被强制转换成布尔值。在背后,iter(outQ.get, 'STOP') 正在进行这样的比较,并且在尝试把结果转换成布尔值时,可能会把异常处理成假(False)。你需要手动写一个循环,从队列中取出项目,检查这个项目是不是字符串类型,然后再和'STOP'进行比较。
while True:
item = outQ.get()
if isinstance(item, basestring) and item == 'STOP':
break
cnt += 1
检查qsize()也可能很好用,因为在输入队列被合并后,没有其他进程会往队列里添加内容。
既然你知道从 outQ
中会得到多少个项目,另一种解决方法就是明确地等待这些项目的数量:
import multiprocessing as mp
import numpy as np
import Queue
N=100
def worker(inQ, outQ):
while True:
i,item=inQ.get()
result = np.random.rand(1,N)
outQ.put((i,result))
inQ.task_done()
def main():
nProcesses = 8
data = range(N)
inQ = mp.JoinableQueue()
outQ = mp.Queue()
for i,item in enumerate(data):
inQ.put((i,item))
for i in xrange(nProcesses):
proc=mp.Process(target=worker, args=[inQ, outQ])
proc.daemon=True
proc.start()
inQ.join()
cnt=0
for _ in range(N):
result=outQ.get()
print(result)
cnt+=1
print(cnt)
print('got {c} items'.format(c=cnt))
if __name__ == '__main__':
main()