第二个模型预测的MXNet超时

2024-04-20 04:23:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在设置一个flask服务器,它加载我的mxnet模型,并有一个predict-Api方法。你知道吗

在测试api时,我注意到预测在mxnetapi中的第二个调用上有一个超时。我所说的超时是指python被困在一个mxnet方法中,并且似乎没完没了地运行。你知道吗

我使用的是python flask和mxnet v.1.4.1。 我尝试升级到mxnet v.1.5.0,但是没有任何改变,错误仍然存在。你知道吗

我已经为predict方法尝试了不同的实现(见下文),但是都超时了。当我切换到keras后端时,一切正常,但我需要使用mxnet。你知道吗

我用这个指南来预测:https://mxnet.incubator.apache.org/versions/master/tutorials/python/predict_image.html

import mxnet as mx
from collections import namedtuple

Batch = namedtuple('Batch', ['data'])

class MxnetBackend:
    def __init__(self):
        print("MXNet Version:", mx.__version__)
        self.sym, self.arg_params, self.aux_params = mx.model.load_checkpoint(prefix='models/kc_mxnet', epoch=0)
        self.mod = mx.mod.Module(symbol=self.sym, 
                    data_names=['/dense_1_input1'], 
                    context=mx.cpu(), 
                    label_names=None)
        self.mod.bind(for_training=False, 
                data_shapes=[('/dense_1_input1', (1, 1, 512, 3010))], 
                label_shapes=self.mod._label_shapes)
        self.mod.set_params(self.arg_params, self.aux_params, allow_missing=True)

    def predict(self, X):#this timeouts on second call
        """
        gets ndarray
        returns ndarray
        """
        X = mx.nd.array(X)
        self.mod.forward(Batch(X))
        res = self.mod.get_outputs()[0].asnumpy()
        return res

    def predict2(self, X):#this timeouts on second call, too
        """
        gets ndarray
        returns ndarray
        """
        return self.mod.predict(X).asnumpy()

在第一次通话中,mxnet工作正常。 我希望第二次调用mxnet时返回值。 我该怎么解决这个问题?你知道吗


Tags: 方法importselfmodflaskdatadefbatch