我正在使用apachebeampythonsdk和Dataflow编写一个推理管道,用于使用TensorFlow模型进行预测。我在DoFn中有预测步骤,但我不希望每次处理包时都要加载模型,因为这非常昂贵。从docshere,“如果需要,在worker上创建参数DoFn的新实例,并且DoFn.设置方法被调用。这可以通过反序列化或其他方式实现。PipelineRunner可以为多个bundle重用DoFn实例。异常终止的DoFn(通过抛出异常)永远不会被重用
class StatefulGetEmbeddingsDoFn(beam.DoFn):
def __init__(self, model_dir):
self.model = None # initialize
self.model_dir = model_dir
def process(self, element):
if not self.model: # load model if model hasn't been loaded yet
global i
i += 1
logging.info('Getting model: {}'.format(i))
self.model = Model(saved_model_dir=self.model_dir)
ids, b64 = element
embeddings = self.model.predict(b64)
res = [
{
'image': _id,
'embeddings': embedding.tolist()
} for _id, embedding in zip(ids, embeddings)
]
return res
似乎每个工人都在不止一次地加载这个模型(我有一个约30-40台机器的集群)。有没有办法防止模型被多次加载?我本以为这个DoFn只能在每台机器上构造一次,但从日志来看,似乎不是这样。。。你知道吗
目前没有回答
相关问题 更多 >
编程相关推荐