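I created a model from an image dataset, and the model is in TensorFlow Lite (.tflite) format. When I try to use it with the Android object detection example, I get the following error: Cannot convert between a TensorFlowLite tensor with type UINT8 and a Java object of type [[[F (which is compatible with the TensorFlowLite type FLOAT32).
Note: my model is quantized, and I built it in Python with TensorFlow Lite Model Maker.
My tflite model input: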
[{'name': 'input_1', 'index': 178, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'shape_signature': array([ -1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Model output:
[{'name': 'Identity', 'index': 179, 'shape': array([ 1, 131], dtype=int32), 'shape_signature': array([ -1, 131], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0), 'quantization_parameters': {'scales': array([0.00390625], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
I tried the following code (Android Java):
private static final String TAG = "TFLiteObjectDetectionAPIModelWithInterpreter";
// Only return this many results.
private static final int NUM_DETECTIONS = 131;
// Float model
private static final float IMAGE_MEAN = 127.5f;
private static final float IMAGE_STD = 127.5f;
// Number of threads in the java app
private static final int NUM_THREADS = 4;
private boolean isModelQuantized;
// Config values.
private int inputSize;
// Pre-allocated buffers.
private final List<String> labels = new ArrayList<>();
private int[] intValues;
// outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
// contains the location of detected boxes
private float[][][] outputLocations;
//private float[][][] outputLocations;
// outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
// contains the classes of detected boxes
private float[][] outputClasses;
// outputScores: array of shape [Batchsize, NUM_DETECTIONS]
// contains the scores of detected boxes
private float[][] outputScores;
// numDetections: array of shape [Batchsize]
// contains the number of detected boxes
private float[] numDetections;
private ByteBuffer imgData;
private MappedByteBuffer tfLiteModel;
private Interpreter.Options tfLiteOptions;
private Interpreter tfLite;
private TFLiteObjectDetectionAPIModel() {}
/** Memory-map the model file in Assets. */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
throws IOException {
AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param modelFilename The model file path relative to the assets folder
* @param labelFilename The label file path relative to the assets folder
* @param inputSize The size of image input
* @param isQuantized Boolean representing model is quantized or not
*/
public static Detector create(
final Context context,
final String modelFilename,
final String labelFilename,
final int inputSize,
final boolean isQuantized)
throws IOException {
final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
MappedByteBuffer modelFile = loadModelFile(context.getAssets(), modelFilename);
MetadataExtractor metadata = new MetadataExtractor(modelFile);
try (BufferedReader br =
new BufferedReader(
new InputStreamReader(
metadata.getAssociatedFile(labelFilename), Charset.defaultCharset()))) {
String line;
while ((line = br.readLine()) != null) {
Log.w(TAG, line);
d.labels.add(line);
}
}
d.inputSize = inputSize;
try {
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(NUM_THREADS);
d.tfLite = new Interpreter(modelFile, options);
d.tfLiteModel = modelFile;
d.tfLiteOptions = options;
} catch (Exception e) {
throw new RuntimeException(e);
}
d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
int numBytesPerChannel;
if (isQuantized) {
numBytesPerChannel = 1; // Quantized
} else {
numBytesPerChannel = 4; // Floating point
}
d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
// d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * 1);
d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
d.outputLocations = new float[1][NUM_DETECTIONS][4];
d.outputClasses = new float[1][NUM_DETECTIONS];
d.outputScores = new float[1][NUM_DETECTIONS];
d.numDetections = new float[1];
return d;
}
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
int pixelValue = intValues[i * inputSize + j];
if (isModelQuantized) {
// Quantized model
imgData.put((byte) ((pixelValue >> 16) & 0xFF));
imgData.put((byte) ((pixelValue >> 8) & 0xFF));
imgData.put((byte) (pixelValue & 0xFF));
} else { // Float model
imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
outputLocations = new float[1][NUM_DETECTIONS][4];
outputClasses = new float[1][NUM_DETECTIONS];
outputScores = new float[1][NUM_DETECTIONS];
numDetections = new float[1];
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
outputMap.put(2, outputScores);
outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
// Show the best detections.
// after scaling them back to the input size.
// You need to use the number of detections from the output and not the NUM_DETECTONS variable
// declared on top
// because on some models, they don't always output the same total number of detections
// For example, your model's NUM_DETECTIONS = 20, but sometimes it only outputs 16 predictions
// If you don't use the output's numDetections, you'll get nonsensical data
int numDetectionsOutput =
min(
NUM_DETECTIONS,
(int) numDetections[0]); // cast from float to integer, use min for safety
final ArrayList<Recognition> recognitions = new ArrayList<>(numDetectionsOutput);
for (int i = 0; i < numDetectionsOutput; ++i) {
final RectF detection =
new RectF(
outputLocations[0][i][1] * inputSize,
outputLocations[0][i][0] * inputSize,
outputLocations[0][i][3] * inputSize,
outputLocations[0][i][2] * inputSize);
recognitions.add(
new Recognition(
"" + i, labels.get((int) outputClasses[0][i]), outputScores[0][i], detection));
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
@Override
public void enableStatLogging(final boolean logStats) {}
@Override
public String getStatString() {
return "";
}
@Override
public void close() {
if (tfLite != null) {
tfLite.close();
tfLite = null;
}
}
@Override
public void setNumThreads(int numThreads) {
if (tfLite != null) {
tfLiteOptions.setNumThreads(numThreads);
recreateInterpreter();
}
}
@Override
public void setUseNNAPI(boolean isChecked) {
if (tfLite != null) {
tfLiteOptions.setUseNNAPI(isChecked);
recreateInterpreter();
}
}
private void recreateInterpreter() {
tfLite.close();
tfLite = new Interpreter(tfLiteModel, tfLiteOptions);
}
}
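When a model is fully quantized, its outputs are byte-sized as well. You are trying to load the outputs into float arrays, namely the outputLocations, outputClasses, outputScores and numDetections buffers that the code above allocates with new float[...].
Change them to byte and things should work. Check whether you also need to dequantize the outputs properly.
Improvement: use the TFLite Support Library. It simplifies pre- and post-processing, quantization and dequantization, and input/output data management.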
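To illustrate the dequantization step mentioned above, here is a minimal sketch in plain Java. It applies the usual affine mapping real = scale * (q - zeroPoint) to raw UINT8 bytes. The class and method names (OutputDequantizer, dequantize, argMax) are hypothetical helpers, not part of TensorFlow Lite, and the scale and zero point have to be taken from your own model's output details.

```java
/**
 * Hypothetical helper (not a TensorFlow Lite class): turns the raw bytes of a
 * UINT8 output tensor into real-valued floats.
 */
final class OutputDequantizer {
    private OutputDequantizer() {}

    /** Applies real = scale * (q - zeroPoint) to every element. */
    static float[] dequantize(byte[] quantized, float scale, int zeroPoint) {
        float[] real = new float[quantized.length];
        for (int i = 0; i < quantized.length; i++) {
            // Java bytes are signed, so mask to recover the unsigned 0..255 value.
            int unsigned = quantized[i] & 0xFF;
            real[i] = scale * (unsigned - zeroPoint);
        }
        return real;
    }

    /** Returns the index of the largest value, or -1 for an empty array. */
    static int argMax(float[] values) {
        int best = -1;
        for (int i = 0; i < values.length; i++) {
            if (best < 0 || values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

With the output details posted in the question (a single UINT8 tensor of shape [1, 131], scale 0.00390625, zero point 0), one plausible use is to allocate a byte[1][131] buffer, pass it to tfLite.run(imgData, output), and then call OutputDequantizer.dequantize(output[0], 0.00390625f, 0). Treat this as an assumption about how the pieces fit together rather than a verified fix for the posted app.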