我现在使用tensorflow(python)来训练我的模型,并希望使用tensorflow(java)在线推断结果。在
计算图有一个返回shape[1,16]结果的操作,张量中的每个元素都是一个字符串。现在我想把结果转换成整个字符串。在
拜佛,打电话给特布菲张量.writeTo在缓冲区中写入数据。但是当我解码最终的缓冲区时,它的标题中有一些意外的字符,我猜最后的字节可能包含一些张量元信息。在
Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");
结果是这样的:第一部分是一些错误的代码,最后一部分是正确的。在
^{pr2}$我想也许形状(1,16)的张量在字节中有元信息,但我不知道如何获取我需要的部分。在
有人知道如何在javatensorflow接口中将多维张量转换成java字符串吗?在
我找到了解决办法! 训练模特时,我打电话给tf.reduce连接在具有形状(1,16)的张量上得到一个标量。 当用java语言进行推理时,我只需获取标量节点,并调用张量字节值()获取张量字节。它将返回不带标题代码的纯结果。在
如果操作的结果具有形状
[1, 16]
,则表示它正在生成16个不同的字符串,而不是一个字符串。在Java中对多维字符串张量的支持是最近才添加的(github commit),在TensorFlow 1.3版及之前版本的预构建二进制文件中没有包含。您必须要么从源代码构建,要么等待TensorFlow 1.4发布。在
有了这个特性,您应该能够用如下方式解码
(1, 16)
形张量:如果您真的需要一个字符串,那么是的,您可以使用
tf.reduce_join
让模型将16个字符串合并为一个字符串,然后提取标量。在相关问题 更多 >
编程相关推荐