冻结CNN模型的Java Tensorflow推理问题

2024-03-28 11:41:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我遇到了一些关于使用javatensorflowapi的问题。在

基本上,我尝试使用Python中训练的冻结模型来预测一些图像,但是我想用Java中的Tensorflow对一些应用程序进行推断,如果这些应用程序可以工作的话。在

我首先将Python模型导出为一个.pb文件,然后可以将其加载到Tensorflow中,它可以用于推理目的,我在Python中测试了它,它的工作没有任何问题。在

然后,我试图修改标签图像.java在GitHub(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java)上可以找到的Java Tensorflow示例中提供的示例。我基本上修改了模型的路径和我将使用的图像。在成功地纠正了一些错误之后,代码是可以运行的,但是我得到了一个错误:

Exception in thread "main" java.lang.UnsupportedOperationException: Generic conv implementation does not support grouped convolutions for now.
 [[{{node conv2d_1/convolution}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_1_0_0, conv2d_1/kernel)]]

一般来说,我对Java和Tensorflow很陌生,我试图找到类似的错误,比如我得到的错误,但没有发现任何有用的东西。我想知道这个错误是否是想告诉我当前的Tensorflow API for Java不支持卷积,如果是这样的话,我会非常惊讶。不管怎么说,我有点不知所措,我想我能做些什么来解决这个问题,我希望有人能帮我找到解决办法。在

一些细节:我在Keras上构建并训练了一个U-Net模型,并使用了GitHub上某个用户的方法,该方法将经过训练的Keras模型转换成一个.pb文件,该文件可以重新加载到Tensorflow上并运行以进行推理(用户:https://github.com/amir-abdi/keras_to_tensorflow)。这个重新加载和推理部分在Python中工作得很好(我测试了它以确定)。在

代码块中似乎发生了错误:

^{pr2}$

这段代码没有改变,就像我说的,我只是改变了指向我的模型和用于测试的示例图像的输入路径。我更改的确切部分可以在下面找到:

^{3}$

我希望这些信息足以理解这个问题。在


Tags: 文件代码https模型图像github应用程序示例
1条回答
网友
1楼 · 发布于 2024-03-28 11:41:59

好吧,最后我找到了我自己问题的答案。在

基本上,错误是因为我给模型输入的图像没有合适的大小(我的图像是512x512,我的模型只需要256x256个图像)。所以,我想问题是输入张量没有正确的维数。在

希望这篇文章能对有同样问题的人有所帮助。在

相关问题 更多 >