把获取的张量转换成字符串的java模型推理?

2024-04-29 12:51:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我现在使用tensorflow(python)来训练我的模型,并希望使用tensorflow(java)在线推断结果。在

计算图有一个返回shape[1,16]结果的操作,张量中的每个元素都是一个字符串。现在我想把结果转换成整个字符串。在

拜佛,打电话给特布菲张量.writeTo在缓冲区中写入数据。但是当我解码最终的缓冲区时,它的标题中有一些意外的字符,我猜最后的字节可能包含一些张量元信息。在

Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");

结果是这样的:第一部分是一些错误的代码,最后一部分是正确的。在

^{pr2}$

我想也许形状(1,16)的张量在字节中有元信息,但我不知道如何获取我需要的部分。在

有人知道如何在javatensorflow接口中将多维张量转换成java字符串吗?在


Tags: 字符串模型信息string字节tensorflowjava缓冲区
2条回答

我找到了解决办法! 训练模特时,我打电话给tf.reduce连接在具有形状(1,16)的张量上得到一个标量。 当用java语言进行推理时,我只需获取标量节点,并调用张量字节值()获取张量字节。它将返回不带标题代码的纯结果。在

如果操作的结果具有形状[1, 16],则表示它正在生成16个不同的字符串,而不是一个字符串。在

Java中对多维字符串张量的支持是最近才添加的(github commit),在TensorFlow 1.3版及之前版本的预构建二进制文件中没有包含。您必须要么从源代码构建,要么等待TensorFlow 1.4发布。在

有了这个特性,您应该能够用如下方式解码(1, 16)形张量:

Tensor predictedTensor = result.get(0);
byte[][][] predictedTokenBytes = predictedTensor.copyTo(new byte[1][16][]);
String[] predictedTokens = new String[16];
for (int i = 0; i < 16; ++i) {
  // This works under the assumption that the model is actually
  // producing UTF-8 strings    
  predictedTokens[i] = new String(predictedTokenBytes[0][i], "UTF-8");
}

如果您真的需要一个字符串,那么是的,您可以使用tf.reduce_join让模型将16个字符串合并为一个字符串,然后提取标量。在

相关问题 更多 >