在Python中使用全局变量进行多进程的正确方法

0 投票
1 回答
66 浏览
提问于 2025-04-14 17:07

我正在处理一个很大的数据集,为了让我的函数正常工作,我需要把数据集分成几部分,然后逐批进行计算。以下是我的代码:

batch_size = 128
results = []

for i in range(0, len(queries), batch_size):
  
  result = linear_kernel(query[i:i+batch_size], dataset)
  results.append(result)

运行这个代码大约需要5个小时。

现在我想使用多进程来加快速度。所以我定义了一个工作函数:

这里的query和dataset是一个稀疏矩阵,使用的是TFIDF向量化方法。

def job(i):
  
  return linear_kernel(query[i: i+batch_size], dataset)


with concurrent.futures.ProcessPoolExecutor() as executor:
    results = executor.map(job, tqdm(range(0, len(query), batch_size)))
 

然后问题来了:

  • 我不知道执行器中的结果是什么,我猜批次可能会被打乱。因为我还需要处理这些结果,所以我需要确保结果的行索引和query数据的行索引是匹配的。我该怎么做呢?我不知道怎么修改输出,以保持行索引i的信息。

  • 其次,使用两个在工作函数外部的变量querydataset是否可以?我对多进程了解不多,如果它在不同的CPU上运行,那么每个处理器会复制数据吗?

1 个回答

1

你其实不需要一个叫做 job 的函数。你的“工作”函数是 linear_kernel。下面是我的代码:

import logging
import random
import time
from concurrent.futures import ProcessPoolExecutor

logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)-8s | %(processName)-14s | %(funcName)-14s | %(message)s",
)


def linear_kernel(query: list, left: int, right: int, dataset):
    logging.debug("Processing batch [%d:%d]", left, right)

    # Fake calculation, which takes a long time to complete
    time.sleep(random.randint(1, 5))
    result = left + right

    logging.debug("Return %r", result)
    return result


def main():
    # Fake data
    query = list(range(20))
    dataset = None

    batch_size = 6
    futures = []

    with ProcessPoolExecutor() as executor:
        for left_index in range(0, len(query), batch_size):
            # do not pass into the function query[i:i+batch_size]
            # because that is a slice notation, which creates a
            # new array in memory. Instead, pass in the array `query`
            # the left and right indices
            futures.append(
                executor.submit(
                    linear_kernel,
                    query=query,
                    left=left_index,
                    right=left_index + batch_size,
                    dataset=dataset,
                )
            )

    results = [future.result() for future in futures]
    logging.debug("results=%r", results)


if __name__ == "__main__":
    main()

示例输出:

DEBUG    | SpawnProcess-1 | linear_kernel  | Processing batch [0:6]
DEBUG    | SpawnProcess-3 | linear_kernel  | Processing batch [6:12]
DEBUG    | SpawnProcess-2 | linear_kernel  | Processing batch [12:18]
DEBUG    | SpawnProcess-4 | linear_kernel  | Processing batch [18:24]
DEBUG    | SpawnProcess-4 | linear_kernel  | Return 42
DEBUG    | SpawnProcess-1 | linear_kernel  | Return 6
DEBUG    | SpawnProcess-3 | linear_kernel  | Return 18
DEBUG    | SpawnProcess-2 | linear_kernel  | Return 30
DEBUG    | MainProcess    | main           | results=[6, 18, 30, 42]

注意事项

  • 我想强调的是,不要像这样创建一个切片 query[i:i+batch_size] 然后把它传给 linear_kernel() 函数。这个切片在内存中其实是一个独立的列表。如果 query 很大而且 batch_size 也很大,我们可能会占用很多内存。因此,最好直接传入 query,同时传入索引,而不是创建那些切片。
  • 从输出中可以看到,虽然有些计算是乱序完成的,但我们仍然得到了有序的结果。这是因为列表 futures 为我们跟踪了顺序。
  • 当然,我不知道你的数据是什么样的,更不知道你的计算是怎么进行的,所以我只能假装这些数据和计算。

撰写回答