Python 多进程:pool.join() 后计数丢失?

2 投票
1 回答
804 浏览
提问于 2025-04-17 22:23

我正在尝试解决一个问题,主要是存储特定长度的子字符串的位置和数量。因为字符串可能很长(比如基因组序列),所以我想用多个进程来加快速度。但是在程序运行的时候,存储这些信息的变量似乎在线程结束后就丢失了所有信息。

import numpy
import multiprocessing
from multiprocessing.managers import BaseManager, DictProxy
from collections import defaultdict, namedtuple, Counter
from functools import partial
import ctypes as c

class MyManager(BaseManager):
        pass

MyManager.register('defaultdict', defaultdict, DictProxy)

def gc_count(seq):
        return int(100 * ((seq.upper().count('G') + seq.upper().count('C') + 0.0) / len(seq)))

def getreads(length, table, counts, genome):
        genome_len = len(genome)
        for start in range(0,genome_len): 
                gc = gc_count(genome[start:start+length])
                table[ (length, gc) ].append( (start) )
                counts[length,gc] +=1

if __name__ == "__main__":
    g = 'ACTACGACTACGACTACGCATCAGCACATACGCATACGCATCAACGACTACGCATACGACCATCAGATCACGACATCAGCATCAGCATCACAGCATCAGCATCAGCACTACAGCATCAGCATCAGCATCAG'
    genome_len = len(g)

    mgr = MyManager()
    mgr.start()
    m = mgr.defaultdict(list)
    mp_arr = multiprocessing.Array(c.c_double, 10*101)
    arr = numpy.frombuffer(mp_arr.get_obj())
    count = arr.reshape(10,101)

    pool = multiprocessing.Pool(9)
    partial_getreads = partial(getreads, table=m, counts=count, genome=g)
    pool.map(partial_getreads, range(1, 10))
    pool.close()
    pool.join()

    for i in range(1, 10):
            for j in range(0,101):
                    print count[i,j]
    for i in range(1, 10):
            for j in  range(0,101):
                    print len(m[(i,j)])

在最后的循环中,打印出来的每个 count 元素都是 0.0,而 m 中的每个列表都是 0,所以我不知道为什么所有的计数都没了。如果我在 getreads(...) 函数中打印计数,我能看到这些值在增加。可是,如果在 getreads(...) 中打印 len(table[ (length, gc) ]) 或在主程序中打印 len(m[(i,j)]),结果都是 0

1 个回答

1

你也可以把你的问题看作一个“映射-归约”的问题,这样就可以避免在多个进程之间共享数据(我想这样会加快计算速度)。你只需要从函数(映射)中返回结果表和计数,然后把所有进程的结果合并起来(归约)。

回到你最初的问题……

Managers的底部,有一条关于可变值或字典和列表中项目修改的相关说明。基本上,你需要把修改过的对象重新分配给容器代理。

l = table[ (length, gc) ]
l.append( (start) )
table[ (length, gc) ] = l

还有一个相关的Stackoverflow帖子,讨论了如何将池映射与数组结合

考虑到这两点,你可以做类似这样的事情:

def getreads(length, table, genome):
        genome_len = len(genome)

        arr = numpy.frombuffer(mp_arr.get_obj())
        counts = arr.reshape(10,101)

        for start in range(0,genome_len): 
                gc = gc_count(genome[start:start+length])
                l = table[ (length, gc) ]
                l.append( (start) )
                table[ (length, gc) ] = l
                counts[length,gc] +=1


if __name__ == "__main__":
    g = 'ACTACGACTACGACTACGCATCAGCACATACGCATACGCATCAACGACTACGCATACGACCATCAGATCACGACATCAGCATCAGCATCACAGCATCAGCATCAGCACTACAGCATCAGCATCAGCATCAG'
    genome_len = len(g)

    mgr = MyManager()
    mgr.start()
    m = mgr.defaultdict(list)
    mp_arr = multiprocessing.Array(c.c_double, 10*101)
    arr = numpy.frombuffer(mp_arr.get_obj())
    count = arr.reshape(10,101)

    pool = multiprocessing.Pool(9)
    partial_getreads = partial(getreads, table=m, genome=g)

    pool.map(partial_getreads, range(1, 10))
    pool.close()
    pool.join()

    arr = numpy.frombuffer(mp_arr.get_obj())
    count = arr.reshape(10,101)

撰写回答