有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

安卓如何通过TensorFlowEnferenceInterface提供布尔占位符。JAVA

我正试图通过Java Tensorflow API启动经过培训的Keras Tensorflow图形
除了标准的输入图像占位符外,此图还包含keras_learning_phase'占位符,需要为其输入布尔值

问题是,TensorFlowInferenceInterface中没有用于布尔值的方法-只能为其提供浮点值双精度整数字节

显然,当我试图通过这个代码将int传递给这个张量时:

inferenceInterface.fillNodeInt("keras_learning_phase",  
                               new int[]{1}, new int[]{0});

我明白了

tensorflow_inference_jni.cc:207 Error during inference: Internal: Output 0 of type int32 does not match declared output type bool for node _recv_keras_learning_phase_0 = _Recvclient_terminated=true, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=4742451733276497694, tensor_name="keras_learning_phase", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/cpu:0"

有没有办法绕过它
也许有可能以某种方式将图中的占位符节点显式转换为常量
或者一开始可以避免在图中创建这个占位符


共 (1) 个答案

  1. # 1 楼答案

    TensorFlowInferenceInterface类本质上是完整的TensorFlow Java API上的一个方便的包装器,它确实支持布尔值

    你也许可以在TensorFlowInferenceInterface中添加一个方法来做你想做的事情。与^{}类似,您可以添加以下内容(请注意,TensorFlow中的布尔表示为一个字节):

    public void fillNodeBool(String inputName, int[] dims, bool[] src) {
      byte[] b = new byte[src.length];
      for (int i = 0; i < src.length; ++i) {
        b[i] = src[i] ? 1 : 0;
      }
      addFeed(inputName, Tensor.create(DatType.BOOL, mkDims(dims), ByteBuffer.wrap(b)));
    }
    

    希望有帮助。如果它有效,我鼓励你回馈TensorFlow代码库