如何在Python中多线程读取对象列表的函数?天体物理示例代码

0 投票
2 回答
725 浏览
提问于 2025-04-18 03:02

这是我第一次在Stack Overflow发帖。我会尽量提供所有必要的信息,但如果还有其他信息可以帮助我更清楚地表达我的问题,请告诉我。

我正在尝试使用pool.map来多线程处理一个在天体物理代码中比较耗时的函数。这个函数的输入是一个对象列表。基本的代码结构是这样的:

有一个代表星星的类,里面包含一些物理属性:

Class Stars:
    def __init__(self,mass,metals,positions,age):
        self.mass = mass
        self.metals = metals
        self.positions = positions
        self.age = age
    def info(self):
        return(self.mass,self.metals,self.positions,self.age)

还有一个这些对象的列表:

stars_list = []
for i in range(nstars):
                stars_list.append(Stars(mass[i],metals[i],positions[i],age[i]))

(其中质量、金属含量、位置和年龄是从另一个脚本中获得的)。

我有一个耗时的函数,它会对这些星星对象进行处理,并返回每个星星的光谱:

def newstars_gen(stars_list):
   ....
   return stellar_nu,stellar_fnu

其中stellar_nu和stellar_fnu是numpy数组。

我想做的是把星星对象的列表(stars_list)分成几块,然后在多个线程上运行newstars_gen函数,以提高速度。所以,我把列表分成了三个子列表,然后尝试通过pool.map来运行我的函数:

p = Pool(processes = 3)
nchunks = 3
chunk_start_indices = []
chunk_start_indices.append(0) #the start index is 0

delta_chunk_indices = nstars / nchunks

for n in range(1,nchunks):
    chunk_start_indices.append(chunk_start_indices[n-1]+delta_chunk_indices)

for n in range(nchunks):
    stars_list_chunk = stars_list[chunk_start_indices[n]:chunk_start_indices[n]+delta_chunk_indices]
    #if we're on the last chunk, we might not have the full list included, so need to make sure that we have that here
    if n == nchunks-1: 
        stars_list_chunk = stars_list[chunk_start_indices[n]:-1]


    chunk_sol = p.map(newstars_gen,stars_list_chunk)

但是当我这样做时,我遇到了以下错误:

File "/Users/[username]/python2.7/multiprocessing/pool.py", line 250, in map
    return self.map_async(func, iterable, chunksize).get()
  File "/Users/[username]/python2.7/multiprocessing/pool.py", line 554, in get
    raise self._value
AttributeError: Stars instance has no attribute '__getitem__'

所以,我对Stars类应该包含什么样的属性感到困惑。我尝试在网上查阅相关资料,但不太确定如何为这个类定义合适的__getitem__。我对面向对象编程(以及Python整体)还很陌生。

任何帮助都非常感谢!

2 个回答

1

我写了一个函数,可以把处理一个可迭代对象(比如你的星星对象列表)分配给多个处理器,这个方法我觉得对你会很有帮助。

from multiprocessing import Process, cpu_count, Lock
from sys import stdout 
from time import clock

def run_multicore_function(iterable, function, func_args = [], max_processes = 0):
    #directly pass in a function that is going to be looped over, and fork those 
    #loops onto independant processors. Any arguments the function needs must be provided as a list.     
    if max_processes == 0:
        cpus = cpu_count()
        if cpus > 7:
            max_processes = cpus - 3
        elif cpus > 3:
            max_processes = cpus - 2
        elif cpus > 1:
            max_processes = cpus - 1
        else:
            max_processes = 1

    running_processes = 0
    child_list = []
    start_time = round(clock())
    elapsed = 0
    counter = 0
    print "Running function %s() on %s cores" % (function.__name__,max_processes)
    #fire up the multi-core!!
    stdout.write("\tJob 0 of %s" % len(iterable),)
    stdout.flush()
    for next_iter in iterable:
       if type(iterable) is dict:
           next_iter = iterable[next_iter]
       while 1:     #Only fork a new process when there is a free processor. 
            if running_processes < max_processes:
                #Start new process                  
                stdout.write("\r\tJob %s of %s (%i sec)" % (counter,len(iterable),elapsed),)
                stdout.flush()                  
                if len(func_args) == 0: 
                    p = Process(target=function, args=(next_iter,))
                else:
                    p = Process(target=function, args=(next_iter,func_args))
                p.start()
                child_list.append(p)
                running_processes += 1
                counter += 1
                break
            else:
                #processor wait loop
                while 1:
                    for next in range(len(child_list)):
                        if child_list[next].is_alive():
                            continue
                        else:
                            child_list.pop(next)
                            running_processes -= 1
                            break
                    if (start_time + elapsed) < round(clock()):
                        elapsed = round(clock()) - start_time
                        stdout.write("\r\tJob %s of %s (%i sec)" % (counter,len(iterable),elapsed),)
                        stdout.flush()

                    if running_processes < max_processes:
                        break

    #wait for remaining processes to complete --> this is the same code as the processor wait loop above
    while len(child_list) > 0:
        for next in range(len(child_list)):
            if child_list[next].is_alive():
                continue
            else:
                child_list.pop(next)
                running_processes -= 1
                break  #need to break out of the for-loop, because the child_list index is changed by pop 
        if (start_time + elapsed) < round(clock()):
            elapsed = round(clock()) - start_time
            stdout.write("\r\tRunning job %s of %s (%i sec)" % (counter,len(iterable),elapsed),)
            stdout.flush()

    print " --> DONE\n"
    return  

作为使用示例,我们可以用你的星星列表,然后把新生成的星星结果发送到一个共享文件里。首先,设置好你的可迭代对象、文件和文件锁。

   star_list = []
   for i in range(nstars):
        stars_list.append(Stars(mass[i],metals[i],positions[i],age[i]))

   outfile = "some/where/output.txt"
   file_lock = Lock()

像这样定义你的耗时函数:

def newstars_gen(stars_list_item,args):   #args = [outfile,file_lock]
    outfile,file_lock = args

        ....

    with file_lock:
        with open(outfile,"a") as handle:
             handle.write(stellar_nu,stellar_fnu)

现在把你的星星列表传入run_multicore_function()函数中。

run_multicore_function(star_list, newstars_gen, [outfile,file_lock])

当所有的项目都计算完毕后,你可以回到输出文件中获取数据,然后继续进行。除了写入文件,你也可以使用multiprocessing.Value或multiprocessing.Array来共享状态,但我遇到过一些问题,就是如果我的列表很大,而我调用的函数又比较快,有时数据会丢失。也许其他人能找到原因。

希望这些内容对你有帮助!
祝好运,
-Steve

1

看起来这里可能有几个问题,可以进行一些清理或者让代码更符合Python的风格。不过,最关键的问题是你错误地使用了pool.multiprocessing.Pool.map。你的newstars_gen函数是期待一个列表的,但p.map会把你给它的列表分成小块,每次只处理一个Star。你可能需要重写newstars_gen,让它一次只处理一个星星,然后把你最后代码块中的第一行和最后一行以外的都去掉。如果newstars_gen中的计算不是独立的(比如,一个星星的质量会影响另一个星星的计算),那么你就需要进行更大幅度的重构。

另外,学习一下列表推导式也会对你有帮助。要知道,其他内置的数据结构(比如setdict)也有类似的用法,同时也可以看看生成器推导式

撰写回答