如何从多个进程中递增共享计数器?

102 投票
8 回答
88418 浏览
提问于 2025-04-15 18:06

我在使用 multiprocessing 模块时遇到了一些问题。我正在用它的 Pool 来创建多个工作者,同时处理很多文件。我希望每处理完一个文件,就能更新一个计数器,这样我就能知道还有多少文件需要处理。下面是一个示例代码:

import os
import multiprocessing

counter = 0


def analyze(file):
    # Analyze the file.
    global counter
    counter += 1
    print counter


if __name__ == '__main__':
    files = os.listdir('/some/directory')
    pool = multiprocessing.Pool(4)
    pool.map(analyze, files)

我找不到解决办法。

8 个回答

18

这是一个非常简单的例子,改编自jkp的回答:

from multiprocessing import Pool, Value
from time import sleep

counter = Value('i', 0)
def f(x):
    global counter
    with counter.get_lock():
        counter.value += 1
    print("counter.value:", counter.value)
    sleep(1)
    return x

with Pool(4) as p:
    r = p.map(f, range(1000*1000))
50

没有竞争条件错误的计数器类:

class Counter(object):
    def __init__(self):
        self.val = multiprocessing.Value('i', 0)

    def increment(self, n=1):
        with self.val.get_lock():
            self.val.value += n

    @property
    def value(self):
        return self.val.value
93

问题在于,counter 这个变量在你的进程之间并没有共享:每个独立的进程都在创建自己的本地实例,并在增加这个实例的值。

你可以查看 文档中的这一部分,里面有一些可以用来在进程之间共享状态的技巧。在你的情况下,你可能想要在你的工作进程之间共享一个 Value 实例。

这里有一个你例子的可运行版本(带有一些虚拟输入数据)。请注意,它使用了全局变量,而我在实际操作中真的建议尽量避免使用全局变量:

from multiprocessing import Pool, Value
from time import sleep

counter = None

def init(args):
    ''' store the counter for later use '''
    global counter
    counter = args

def analyze_data(args):
    ''' increment the global counter, do something with the input '''
    global counter
    # += operation is not atomic, so we need to get a lock:
    with counter.get_lock():
        counter.value += 1
    print counter.value
    return args * 10

if __name__ == '__main__':
    #inputs = os.listdir(some_directory)

    #
    # initialize a cross-process counter and the input lists
    #
    counter = Value('i', 0)
    inputs = [1, 2, 3, 4]

    #
    # create the pool of workers, ensuring each one receives the counter 
    # as it starts. 
    #
    p = Pool(initializer = init, initargs = (counter, ))
    i = p.map_async(analyze_data, inputs, chunksize = 1)
    i.wait()
    print i.get()

撰写回答