如何缩短预训练模型的加载时间?

2024-03-28 20:08:12 发布

您现在位置:Python中文网/ 问答频道 /正文

使用ResNet50加载“Imagenet”的权重时,每次加载权重几乎需要10-11秒。 有没有办法缩短装载时间

代码:

from flask import Flask, render_template, request
from werkzeug import secure_filename
from flask import request,Flask
import json
import os
import time

from keras.preprocessing import image as image_util 
from keras.applications.imagenet_utils import preprocess_input
from keras.applications.imagenet_utils import decode_predictions
# from keras.applications import ResNet50
from keras.applications.inception_v3 import InceptionV3
import numpy as np

app = Flask(__name__)

@app.route('/object_rec', methods=['POST'])
def object_rec():

      f = request.files['file']
      file_path = ("./upload/"+secure_filename(f.filename))
      f.save(file_path)
      image = image_util.load_img(file_path,target_size=(299,299))
      image = image_util.img_to_array(image)
      image = np.expand_dims(image,axis=0) #(224,224,3) --> (1,224,224,3)
      image = preprocess_input(image)

      start_time = time.time()
      model = InceptionV3(weights="imagenet")
      pred = model.predict(image)
      p = decode_predictions(pred)

      ans = p[0][0]
      acc = ans[2]
      acc = str(acc)
      if ans[1] == "Granny_Smith":
            ans = ans[1]
            ans = 'Apple'
      else:
            ans = ans[1]
      print("THE PREDICTED IMAGE IS: "+ans)
      print("THE ACCURACY IS: "+acc)
      print("--- %s seconds ---" % (time.time() - start_time))
      result = {
            "status": True,
            "object": ans,
            "score":acc
      }
      result = json.dumps(result)
      return result

if __name__ == '__main__':
   app.run(host='0.0.0.0',port=6000,debug=True)

所用时间在8-11秒之间。 如果它在3-4秒内加载模型并进行分类,我会很好

提前谢谢


Tags: fromimageimportflasktimerequestutilresult
1条回答
网友
1楼 · 发布于 2024-03-28 20:08:12

您可以这样做,即在特定会话中加载模型,然后每次您想要使用该模型时,只需设置该特定会话,然后只需调用predict以预测您需要它的位置:

app = Flask(__name__)
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras! 
# Otherwise, their weights will be unavailable in the threads after the 
session there has been set
set_session(sess)

model = InceptionV3(weights="imagenet")

@app.route('/object_rec', methods=['POST'])
def object_rec():
   global sess
   global graph
   with graph.as_default():
      set_session(sess)
      model.predict(...)

if __name__ == '__main__':
   app.run(host='0.0.0.0',port=6000,debug=True)

相关问题 更多 >