DoFn构造多少次?

2024-05-12 23:29:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用apachebeampythonsdk和Dataflow编写一个推理管道,用于使用TensorFlow模型进行预测。我在DoFn中有预测步骤,但我不希望每次处理包时都要加载模型,因为这非常昂贵。从docshere,“如果需要,在worker上创建参数DoFn的新实例,并且DoFn.设置方法被调用。这可以通过反序列化或其他方式实现。PipelineRunner可以为多个bundle重用DoFn实例。异常终止的DoFn(通过抛出异常)永远不会被重用

class StatefulGetEmbeddingsDoFn(beam.DoFn):
    def __init__(self, model_dir):
         self.model = None # initialize
         self.model_dir = model_dir

    def process(self, element):
         if not self.model: # load model if model hasn't been loaded yet
             global i
             i += 1
             logging.info('Getting model: {}'.format(i))
             self.model = Model(saved_model_dir=self.model_dir)


         ids, b64 = element
         embeddings = self.model.predict(b64)

         res = [
            {
                'image': _id,
                'embeddings': embedding.tolist()
            } for _id, embedding in zip(ids, embeddings)
         ]
         return res

似乎每个工人都在不止一次地加载这个模型(我有一个约30-40台机器的集群)。有没有办法防止模型被多次加载?我本以为这个DoFn只能在每台机器上构造一次,但从日志来看,似乎不是这样。。。你知道吗


Tags: 实例模型selfididsmodelifdef