确定线程池何时完成队列处理
我正在尝试实现一个线程池,用来处理一个任务队列,使用的是ThreadPool
和Queue
。一开始有一个初始的任务队列,然后每个任务也可能会往这个队列里添加更多的任务。问题是我不知道怎么在队列为空且线程池完成处理之前阻塞,但又能检查队列并提交任何新任务到线程池,这些新任务是被添加到队列里的。我不能简单地调用ThreadPool.join()
,因为我需要保持线程池开放,以便接收新任务。
举个例子:
from multiprocessing.pool import ThreadPool
from Queue import Queue
from random import random
import time
import threading
queue = Queue()
pool = ThreadPool()
stdout_lock = threading.Lock()
def foobar_task():
with stdout_lock: print "task called"
if random() > .25:
with stdout_lock: print "task appended to queue"
queue.append(foobar_task)
time.sleep(1)
# set up initial queue
for n in range(5):
queue.put(foobar_task)
# run the thread pool
while not queue.empty():
task = queue.get()
pool.apply_async(task)
with stdout_lock: print "pool is closed"
pool.close()
pool.join()
这段代码的输出是:
pool is closed
task called
task appended to queue
task called
task appended to queue
task called
task appended to queue
task called
task appended to queue
task called
task appended to queue
这段代码在foobar_tasks
还没添加到队列之前就退出了循环,所以添加的任务从来没有被提交到线程池。我找不到任何方法来判断线程池是否还有活跃的工作线程。我尝试了以下方法:
while not queue.empty() or any(worker.is_alive() for worker in pool._pool):
if not queue.empty():
task = queue.get()
pool.apply_async(task)
else:
with stdout_lock: print "waiting for worker threads to complete..."
time.sleep(1)
但是似乎worker.is_alive()
总是返回真,所以这就进入了一个无限循环。
有没有更好的方法来解决这个问题呢?
1 个回答
2
- 在每个任务处理完后,调用 queue.task_done。
- 然后你可以调用 queue.join(),这样主线程会被阻塞,直到所有任务都完成。
- 要结束工作线程,可以在队列中放一个哨兵(比如
None
),然后让foobar_task
在收到这个哨兵时跳出while-loop
。 - 我觉得用
threading.Thread
来实现这个比用ThreadPool
更简单。
import random
import time
import threading
import logging
import Queue
logger=logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
sentinel=None
queue = Queue.Queue()
num_threads = 5
def foobar_task(queue):
while True:
n = queue.get()
logger.info('task called: {n}'.format(n=n))
if n is sentinel: break
n=random.random()
if n > .25:
logger.info("task appended to queue")
queue.put(n)
queue.task_done()
# set up initial queue
for i in range(num_threads):
queue.put(i)
threads=[threading.Thread(target=foobar_task,args=(queue,))
for n in range(num_threads)]
for t in threads:
t.start()
queue.join()
for i in range(num_threads):
queue.put(sentinel)
for t in threads:
t.join()
logger.info("threads are closed")