在Keras中对生成器进行多处理时,如何为每个fork建立单独的数据库连接?

2024-04-24 12:02:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我将Keras与fit_generator()一起使用。我的生成器连接到一个数据库(在我的例子中是MongoDB)来获取每个批的数据。如果我使用fit_generator()的多处理标志,我会得到以下警告:

UserWarning: MongoClient opened before fork. Create MongoClient only after forking.

我在__init__()期间连接到数据库:

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        collection = MagicMongoDBConnector()

    def __len__(self):
        ...

    def __getitem__(self, idx):
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...

我假设每个epoch需要一个单独的连接,但不幸的是,没有可用的on_epoch_begin(self)回调(如here)。你知道吗

所以有两个问题:
如果使用多处理,Keras如何以及何时分叉生成器? 如何摆脱MongoClient警告并在每个fork内部连接?你知道吗


Tags: self数据库警告initonmongodbdeffork
2条回答

我没有mongo DB可供测试,但这可能有用-您可以获取集合(连接?)在每个进程的第一个get项上。你知道吗

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        self.collection = None

    def __len__(self):
        ...

    def __getitem__(self, idx):
        if self.collection is None:
            self.collection = MagicMongoDBConnector()
        # Continue with your code
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...

如果您使用的是python3.7,那么可以使用os.register_at_fork触发创建数据库连接

例如,您可以执行以下操作:

from os import register_at_fork

def reinit_dbcon():
    generator_obj.collection = MagicMongoDBConnector()

register_at_fork(after_in_child=reinit_dbcon)

在你呼叫fit_generator之前的某个地方。假设对象是全局的

相关问题 更多 >