numpy与多进程队列组合打乱了队列顺序

4 投票
2 回答
2156 浏览
提问于 2025-04-16 13:48

我正在使用以下模式来进行多进程处理:

    for item in data:
        inQ.put(item)

    for i in xrange(nProcesses):
        inQ.put('STOP')
        multiprocessing.Process(target=worker, args=(inQ, outQ)).start()

    inQ.join()
    outQ.put('STOP')

    for result in iter(outQ.get, 'STOP'):
        # save result

这个方法运行得很好。但是,如果我通过 outQ 发送一个 numpy 数组,'STOP' 就不会出现在 outQ 的末尾,这导致我获取结果的循环提前结束。

这里有一些代码可以复现这个问题。

import multiprocessing
import numpy as np

def worker(inQ, outQ):
    for i in iter(inQ.get, 'STOP'):
        result = np.random.rand(1,100)
        outQ.put(result)
        inQ.task_done()
    inQ.task_done() # for the 'STOP'

def main():
    nProcesses = 8
    data = range(1000)

    inQ = multiprocessing.JoinableQueue()
    outQ = multiprocessing.Queue()
    for item in data:
        inQ.put(item)

    for i in xrange(nProcesses):
        inQ.put('STOP')
        multiprocessing.Process(target=worker, args=(inQ, outQ)).start()

    inQ.join()
    print outQ.qsize()
    outQ.put('STOP')

    cnt = 0
    for result in iter(outQ.get, 'STOP'):
        cnt += 1
    print "got %d items" % cnt
    print outQ.qsize()

if __name__ == '__main__':
    main()

如果你把 result = np.random.rand(1,100) 替换成类似 result = i*i 的东西,代码就会按预期工作。

这到底发生了什么?我是不是做错了什么根本性的事情?我本来以为在 inQ.join() 之后调用 outQ.put() 就能达到我想要的效果,因为 join() 会阻塞,直到所有进程都完成了所有的 put() 操作。

一个对我有效的解决方法是用 while outQ.qsize() > 0 来进行结果获取循环,这样也能正常工作。但是我听说 qsize() 不是很可靠。它在不同进程运行时就不可靠吗?在完成 inQ.join() 后依赖 qsize() 是否安全呢?

我预计会有人建议使用 multiprocessing.Pool.map(),但我在用 numpy 数组(ndarrays)时遇到了 pickle 错误。

谢谢你们的关注!

2 个回答

2

numpy数组可以进行复杂的比较。所以,当你写 a=='STOP' 时,返回的是一个numpy数组,而不是一个布尔值(真或假)。这个numpy数组不能被强制转换成布尔值。在背后,iter(outQ.get, 'STOP') 正在进行这样的比较,并且在尝试把结果转换成布尔值时,可能会把异常处理成假(False)。你需要手动写一个循环,从队列中取出项目,检查这个项目是不是字符串类型,然后再和'STOP'进行比较。

while True:
    item = outQ.get()
    if isinstance(item, basestring) and item == 'STOP':
        break
    cnt += 1

检查qsize()也可能很好用,因为在输入队列被合并后,没有其他进程会往队列里添加内容。

2

既然你知道从 outQ 中会得到多少个项目,另一种解决方法就是明确地等待这些项目的数量:

import multiprocessing as mp
import numpy as np
import Queue

N=100

def worker(inQ, outQ):
    while True:
        i,item=inQ.get()
        result = np.random.rand(1,N)
        outQ.put((i,result))
        inQ.task_done()

def main():
    nProcesses = 8
    data = range(N)
    inQ = mp.JoinableQueue()
    outQ = mp.Queue()    

    for i,item in enumerate(data):
        inQ.put((i,item))

    for i in xrange(nProcesses):
        proc=mp.Process(target=worker, args=[inQ, outQ])
        proc.daemon=True
        proc.start()

    inQ.join()
    cnt=0
    for _ in range(N):
        result=outQ.get()
        print(result)
        cnt+=1
        print(cnt)      
    print('got {c} items'.format(c=cnt))

if __name__ == '__main__':
    main()

撰写回答