Problem decoding the output of a tflite model on Android

Posted 2024-04-29 12:40:25


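I have a project in which I have to:

  1. Train a neural network. I found this handwritten text recognition project, which offers three different architectures. I chose "puigcerver".
  2. Convert it to a tflite model.
  3. Load it into an Android app and get the output.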

The first two points went fine, but the last one has me stuck. I can get an output (a 3D tensor of shape [1][128][98]), but I don't know how to decode it.

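I have two main problems:

  1. The tflite model's output is a 3D float tensor in which the 98 values of each 1D array [N] should represent, for the N-th character of a 128-character sentence, the probability of each character in the charset. However, in the article the author states that the charset consists of 95 characters: Article. So the first problem is that I have 3 values (for each character of the sentence) that I did not expect.
  2. Except for the last value (the 98th), which is around 0.98/0.99, and a few others around 0.002/0.004/0.008, all values of the tensor are very small (2.15..e-24 and lower). If I treat them as probabilities and pick the highest value at each position (excluding the 96th, 97th and 98th values), I get sentences like "llllqqqqqqqqggggggg oo…pppp…" (obviously wrong).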
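One possible reading of the three extra values, stated as an assumption and not verified against the project: the conversion script further down loads a `ctc_loss_lambda_func`, which suggests the network was trained with CTC loss. CTC outputs normally carry one extra class for the "blank" symbol, and a tokenizer may add special tokens (for example padding and unknown) on top of the 95 printable characters. Under that reading, the large last value is the blank, and taking the highest non-blank value at every step produces runs of repeated garbage. A minimal greedy CTC decoding sketch in Python, assuming the blank is the last class; the function name and the toy charset are hypothetical:

```python
def ctc_greedy_decode(probs, charset, blank_index=None):
    """Greedy CTC decoding: arg-max per time step, merge repeats, drop blanks.

    probs: list of time steps, each a list of class probabilities.
    charset: string (or list) mapping class index -> character.
    blank_index: index of the CTC blank; ASSUMED to be the last class.
    """
    if blank_index is None:
        blank_index = len(probs[0]) - 1
    decoded = []
    previous = None
    for row in probs:
        best = max(range(len(row)), key=row.__getitem__)
        # Emit a symbol only when it differs from the previous step
        # and is not the blank.
        if best != previous and best != blank_index:
            decoded.append(charset[best])
        previous = best
    return "".join(decoded)


# Toy example: 3 classes ("a", "b", blank) and 6 time steps.
toy = [
    [0.9, 0.05, 0.05],  # a
    [0.8, 0.1, 0.1],    # a again (merged with the previous step)
    [0.1, 0.1, 0.8],    # blank
    [0.7, 0.2, 0.1],    # a (kept, because a blank separated it)
    [0.1, 0.8, 0.1],    # b
    [0.0, 0.1, 0.9],    # blank
]
print(ctc_greedy_decode(toy, "ab"))  # aab
```

For the real model, `charset` would have to match the index-to-character mapping used during training, which is not shown in this post.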

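I tried running inference on the same image with the original network (the --image option) and the result was fine, so I think something may be going wrong when loading the tflite model or the image. I also thought that maybe I have to perform a beam (or greedy) search rather than just looking for the highest value.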
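For reference, a minimal sketch of CTC prefix beam search in Python, the usual alternative to greedy decoding. It assumes each row is a probability distribution and that the blank is the last class (both assumptions about this model); the function name and the default beam width are illustrative, and no language model is involved:

```python
from collections import defaultdict


def ctc_beam_search(probs, blank_index=None, beam_width=3):
    """CTC prefix beam search; returns the best sequence of class indices."""
    if not probs:
        return []
    if blank_index is None:
        blank_index = len(probs[0]) - 1  # ASSUMPTION: blank is the last class
    # prefix -> (prob. of paths ending in blank, prob. of paths ending in non-blank)
    beams = {(): (1.0, 0.0)}
    for row in probs:
        next_beams = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_blank, p_non_blank) in beams.items():
            for c, p in enumerate(row):
                if c == blank_index:
                    # A blank keeps the prefix unchanged.
                    next_beams[prefix][0] += (p_blank + p_non_blank) * p
                elif prefix and prefix[-1] == c:
                    # Repeated symbol: merges unless a blank came in between.
                    next_beams[prefix][1] += p_non_blank * p
                    next_beams[prefix + (c,)][1] += p_blank * p
                else:
                    next_beams[prefix + (c,)][1] += (p_blank + p_non_blank) * p
        # Keep only the most probable prefixes.
        ranked = sorted(next_beams.items(),
                        key=lambda kv: kv[1][0] + kv[1][1], reverse=True)
        beams = {prefix: (pb, pnb) for prefix, (pb, pnb) in ranked[:beam_width]}
    best_prefix = max(beams.items(), key=lambda kv: kv[1][0] + kv[1][1])[0]
    return list(best_prefix)


# Two classes: index 0 = "a", index 1 = blank.
# Greedy picks blank twice (empty string, p = 0.36); summing all paths that
# collapse to "a" gives 0.24 + 0.24 + 0.16 = 0.64, so beam search returns "a".
print(ctc_beam_search([[0.4, 0.6], [0.4, 0.6]]))  # [0]
```

On sharply peaked outputs like the rows shown at the end of this post, greedy and beam decoding will usually agree, so a wrong result is more likely to come from the input side than from the search strategy.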

So my question is: is the output really a tensor of probabilities, or am I missing something? How do I decode the output correctly?




TFLite conversion:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model

from network.model import HTRModel

keras.backend.clear_session()

# Load the trained model without compiling, then compile it with the project's CTC loss.
model = load_model(
    "checkpoint_weights.hdf5",
    custom_objects={"ctc_loss_lambda_func": HTRModel.ctc_loss_lambda_func},
    compile=False,
)
model.compile(loss=HTRModel.ctc_loss_lambda_func)

# Convert the Keras model to TFLite, allowing custom ops.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.allow_custom_ops = True

tflite_model = converter.convert()
with open("tflite_model.tflite", "wb") as f:
    f.write(tflite_model)


Model and image loading (in the Android app, Kotlin):

...

@Throws(IOException::class)
fun initializeInterpreter() {

    // Load the TF Lite model
    val assetManager = context.assets
    println("Loading model file...")
    val model = loadModelFile(assetManager)

    println("Initializing TF Lite interpreter...")
    // Initialize TF Lite Interpreter (NNAPI currently left disabled)
    val options = Interpreter.Options()
    //options.setUseNNAPI(true)
    val interpreter = Interpreter(model, options)

    // Read input shape from model file
    println("Reading input shape from model...")
    val inputShape = interpreter.getInputTensor(0).shape()
    inputImageWidth = inputShape[1]
    println("Shape 1: " + inputShape[1])
    inputImageHeight = inputShape[2]
    println("Shape 2: " + inputShape[2])
    modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
    println("Total input size: $modelInputSize")

    // Finish interpreter initialization
    this.interpreter = interpreter
    isInitialized = true
    Log.d(TAG, "Initialized TFLite interpreter.")
  }

  @Throws(IOException::class)
  private fun loadModelFile(assetManager: AssetManager): ByteBuffer {
    val fileDescriptor = assetManager.openFd(MODEL_FILE)
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
    val fileChannel = inputStream.channel
    val startOffset = fileDescriptor.startOffset
    val declaredLength = fileDescriptor.declaredLength
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
  }

  fun classify(bitmap: Bitmap): String{
    if (!isInitialized) {
      throw IllegalStateException("TF Lite Interpreter is not initialized yet.")
    }

    /*var startTime: Long
    var elapsedTime: Long*/

    // Preprocessing: resize the input

    println("Preprocessing - Resizing the input..")
    //startTime = System.nanoTime()
    val resizedImage = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true)
    val byteBuffer = convertBitmapToByteBuffer(resizedImage)
    //elapsedTime = (System.nanoTime() - startTime) / 1000000
    //Log.d(TAG, "Preprocessing time = " + elapsedTime + "ms")

    //startTime = System.nanoTime()

    println("Running interpreter...")
    val result = Array(1) { Array(128) { FloatArray(98)} }
    interpreter?.run(byteBuffer, result)
    //elapsedTime = (System.nanoTime() - startTime) / 1000000
    //Log.d(TAG, "Inference time = " + elapsedTime + "ms")

    println("Result: $result")
    predicted = result
    val r = result[0]
    var out = "["
    // Concatenate the per-position strings for all 128 time steps
    for (i in 0..127) {
      out += getOutputString(r[i])
    }

    out += "]"
    //return out

    return decodeSentece()

  }

  fun classifyAsync(bitmap: Bitmap): Task<String> {
    return call(executorService, Callable<String> { classify(bitmap) })
  }

  fun close() {
    call(
      executorService,
      Callable<String> {
        interpreter?.close()
        Log.d(TAG, "Closed TFLite interpreter.")
        null
      }
    )
  }

  private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
    val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
    byteBuffer.order(ByteOrder.nativeOrder())

    val pixels = IntArray(inputImageWidth * inputImageHeight)
    bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

    for (pixelValue in pixels) {
      val r = (pixelValue shr 16 and 0xFF)
      val g = (pixelValue shr 8 and 0xFF)
      val b = (pixelValue and 0xFF)

      // Convert RGB to grayscale and normalize pixel value to [0..1]
      val normalizedPixelValue = (r + g + b) / 3.0f / 255.0f
      byteBuffer.putFloat(normalizedPixelValue)
    }

    return byteBuffer
  }
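One thing that may be worth checking, stated as a hypothesis only: `initializeInterpreter` reads `inputShape[1]` as the width and `inputShape[2]` as the height, which implies an input laid out as [1][width][height][1], while `getPixels` returns pixels row by row, that is, in [height][width] order. If the model really expects the width axis first, the buffer would have to be filled column by column instead. A small Python sketch of that index reordering (the helper name is hypothetical, and whether the model needs it has to be verified against the original project's preprocessing, including its normalization):

```python
def to_width_major(pixels, width, height):
    """Reorder a row-major (height x width) pixel list into width-major order.

    pixels[y * width + x] is the pixel at column x, row y, which is the order
    Bitmap.getPixels produces. The result lists column 0 from top to bottom,
    then column 1, and so on.
    """
    if len(pixels) != width * height:
        raise ValueError("pixel count does not match width * height")
    return [pixels[y * width + x] for x in range(width) for y in range(height)]


# 3 x 2 image (width 3, height 2) with rows [1, 2, 3] and [4, 5, 6].
print(to_width_major([1, 2, 3, 4, 5, 6], width=3, height=2))  # [1, 4, 2, 5, 3, 6]
```

In the Kotlin code this would correspond to iterating over x in the outer loop and y in the inner loop when calling `putFloat`.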

These are (for reference only) the first three rows of the output tensor I get:

[[ 1.02767E-9 , 1.0279699E-9 , 3.139572E-4 , 2.1537318E-4 , 0.002740169 , 3.247818E-5 , 1.0521399E-5 , 2.9409595E-4 , 3.4732078E-4 , 1.8211146E-6 , 1.8896988E-5 , 1.0662472E-5 , 0.30825683 , 1.9493811E-5 , 0.15347835 , 7.209303E-5 , 0.10027219 , 1.7075523E-5 , 1.2530926E-5 , 0.011397267 , 0.004646272 , 2.7380593E-6 , 1.21792706E-4 , 0.0030580924 , 0.0037772718 , 0.01746696 , 7.144082E-4 , 2.8029652E-5 , 5.0955594E-5 , 0.030346025 , 6.488925E-4 , 4.2334714E-4 , 0.0010043866 , 2.5545945E-4 , 0.0057575954 , 9.177484E-4 , 8.877261E-6 , 8.9535155E-5 , 1.8161612E-4 , 9.084007E-5 , 0.16994531 , 5.3004638E-5 , 0.0055839983 , 2.0921101E-4 , 5.918604E-4 , 1.8135815E-4 , 1.8600904E-4 , 1.3010694E-5 , 7.701663E-5 , 0.011575605 , 0.001689526 , 7.148706E-5 , 4.6759746E-5 , 4.6543145E-4 , 8.721436E-7 , 0.007293349 , 1.371512E-4 , 7.3025643E-4 , 1.2985134E-4 , 2.2911412E-5 , 5.126359E-4 , 6.7653934E-7 , 8.5107595E-6 , 2.8127617E-6 , 1.8223985E-6 , 0.0017414617 , 5.9720107E-5 , 1.0222625E-9 , 1.095164E-9 , 4.0091674E-8 , 6.4829896E-5 , 0.0028546597 , 9.027818E-7 , 1.0092357E-9 , 1.0142923E-8 , 2.4069648E-5 , 0.020749098 , 3.682649E-4 , 2.9179654E-7 , 2.4412808E-5 , 3.0341505E-6 , 1.0448614E-9 , 9.900646E-10 , 9.750201E-10 , 3.0151805E-5 , 1.0438933E-9 , 1.0359494E-9 , 1.0682695E-9 , 1.0451963E-9 , 1.0827725E-9 , 1.0221299E-9 , 1.019132E-9 , 9.502418E-10 , 9.924109E-10 , 1.0422022E-9 , 1.0894825E-9 , 0.019479617 , 0.107967034]
[ 4.1440778E-15 , 4.1381066E-15 , 2.8557254E-11 , 1.8157608E-14 , 8.410285E-15 , 2.9407989E-16 , 1.060987E-16 , 2.4269218E-14 , 7.374521E-16 , 1.5804724E-18 , 5.3503064E-16 , 3.1957313E-14 , 3.8080086E-6 , 3.4774015E-13 , 9.287003E-7 , 6.938215E-10 , 1.06444844E-4 , 1.3310229E-11 , 1.4083793E-11 , 6.8555856E-8 , 8.523E-8 , 1.7820193E-15 , 8.969676E-11 , 1.0272463E-8 , 3.4963513E-8 , 7.1322137E-7 , 1.10203715E-7 , 3.426864E-11 , 2.6655234E-14 , 2.8331244E-5 , 6.8602155E-8 , 2.5318696E-9 , 2.9697837E-8 , 2.6458968E-9 , 3.169576E-9 , 9.308899E-10 , 1.7746693E-11 , 2.1560531E-16 , 3.1369616E-12 , 2.1284496E-15 , 8.978029E-11 , 3.51387E-14 , 1.712866E-11 , 3.6545718E-14 , 2.1302052E-14 , 1.5469606E-12 , 1.8957156E-12 , 5.3142797E-17 , 3.2412155E-16 , 3.797056E-12 , 8.453629E-11 , 5.663211E-12 , 1.2467538E-14 , 1.8673284E-13 , 1.1108751E-15 , 1.6332011E-10 , 1.0867543E-13 , 1.4709664E-11 , 1.1832392E-13 , 3.5229395E-16 , 2.2621423E-12 , 1.4893199E-14 , 1.1778153E-16 , 4.108518E-16 , 2.5608845E-16 , 2.2514338E-11 , 6.590058E-16 , 4.1018955E-15 , 4.210221E-15 , 4.0438975E-15 , 1.7089807E-13 , 1.7934002E-16 , 3.64835E-19 , 4.5916805E-15 , 7.155464E-15 , 2.0624084E-12 , 5.2085922E-9 , 9.621887E-10 , 2.6019372E-16 , 1.3379034E-14 , 5.9298675E-16 , 4.35525E-15 , 4.027287E-15 , 4.065149E-15 , 3.1142788E-15 , 4.126221E-15 , 4.1317653E-15 , 4.5336524E-15 , 4.308452E-15 , 4.2372895E-15 , 3.9838916E-15 , 4.3225607E-15 , 3.967846E-15 , 4.1830697E-15 , 4.086403E-15 , 4.2691384E-15 , 0.003423732 , 0.99643564]
[ 1.1704519E-17 , 1.1569569E-17 , 5.2640887E-15 , 2.9042624E-19 , 4.888624E-22 , 8.800831E-24 , 2.2810482E-22 , 1.0456246E-20 , 4.613259E-23 , 2.9788603E-24 , 4.2340931E-22 , 5.365159E-19 , 1.9403133E-8 , 3.617449E-18 , 8.07217E-9 , 3.6707356E-12 , 6.9733637E-6 , 1.00665275E-14 , 1.2904081E-13 , 2.300352E-10 , 3.3837338E-10 , 2.1202834E-21 , 4.4093415E-14 , 2.4844083E-11 , 1.8735478E-9 , 2.7742542E-8 , 1.5738684E-9 , 2.9336394E-13 , 9.488479E-19 , 2.5946217E-6 , 2.9507443E-9 , 1.29704225E-11 , 1.9227062E-9 , 1.7503684E-11 , 4.312201E-12 , 1.6009424E-12 , 2.0033327E-13 , 1.4647874E-23 , 4.3650106E-17 , 2.8878068E-23 , 1.3077074E-17 , 1.0820965E-21 , 1.6695615E-17 , 3.3014689E-21 , 8.4159397E-22 , 5.0169094E-19 , 1.9141252E-17 , 5.20164E-26 , 2.0089118E-24 , 2.2912117E-19 , 1.5849434E-15 , 4.2030135E-17 , 2.2811363E-21 , 2.2500594E-20 , 7.502992E-21 , 7.9013944E-16 , 5.3322657E-20 , 4.289789E-16 , 2.5743777E-20 , 1.5604532E-24 , 1.4845246E-18 , 2.1423116E-18 , 1.1037277E-24 , 1.1757478E-21 , 1.5670599E-20 , 7.630421E-15 , 9.0307036E-23 , 1.2181271E-17 , 1.18615665E-17 , 8.01336E-19 , 4.587686E-17 , 2.1303416E-24 , 4.3491843E-26 , 1.3229859E-17 , 1.1444598E-17 , 2.6614466E-13 , 3.7416266E-9 , 2.5167621E-9 , 3.0683186E-20 , 5.911551E-17 , 1.07436035E-20 , 1.1928041E-17 , 1.2051209E-17 , 1.1595416E-17 , 1.8123445E-20 , 1.1205827E-17 , 1.1453464E-17 , 1.3506912E-17 , 1.2615066E-17 , 1.1414165E-17 , 1.01463546E-17 , 1.2532566E-17 , 1.1185157E-17 , 1.1822221E-17 , 1.1537131E-17 , 1.184172E-17 , 0.1587565 , 0.84123385 ]
...
]
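A quick way to test whether each row is a probability distribution is to check that its values are non-negative and sum to about 1. In the third row printed above, the last two entries (0.1587565 and 0.84123385) alone sum to about 0.99999, which is consistent with a softmax output whose mass sits on the final classes. A minimal Python sketch; the function names and the tolerance are illustrative choices:

```python
def looks_like_probabilities(row, tol=1e-3):
    """True if all values are non-negative and the row sums to ~1."""
    return all(v >= 0.0 for v in row) and abs(sum(row) - 1.0) <= tol


def top_classes(row, k=3):
    """Return the k (index, value) pairs with the highest values."""
    return sorted(enumerate(row), key=lambda iv: iv[1], reverse=True)[:k]


# Simplified stand-in for the third row above: the 96 tiny values are
# replaced by zeros, only the last two entries are copied from the post.
row = [0.0] * 96 + [0.1587565, 0.84123385]
print(looks_like_probabilities(row))  # True
print(top_classes(row, k=2))          # [(97, 0.84123385), (96, 0.1587565)]
```

If the rows pass this check, the remaining question is which class each index stands for, not whether the values are probabilities.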

Sorry for the silly question, but I'm new to TensorFlow and TensorFlow Lite... Thanks in advance.

