Python中文网

一个关于 编程问题的解答网站.

有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

在Java中使用经过训练的TensorFlow模型

我已经用Python训练了一个TensorFlow模型,并希望在Java代码中使用它。通过以下代码对模型进行培训:

def input_fn():
    features = {'a': tf.constant([[1],[2]]),
                'b': tf.constant([[3],[4]]) }
    labels = tf.constant([0, 1])
    return features, labels

feature_a = tf.contrib.layers.sparse_column_with_integerized_feature("a", bucket_size=10)
feature_b = tf.contrib.layers.sparse_column_with_integerized_feature("b", bucket_size=10)
feature_columns = [feature_a, feature_b]

model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns)
model.fit(input_fn=input_fn, steps=10)

现在我想保存这个模型,以便在Java中使用它。似乎export_savedmodel是新的/首选的储蓄方式,所以我尝试了:

feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
model.export_savedmodel('export', serving_input_fn, as_text=True)

这将生成一个保存的模型,可以使用

model = SavedModelBundle.load(dir, "serve");
model.session().runner()
    .feed("input_example_tensor", input)
    .fetch("linear/binary_logistic_head/predictions/probabilities")
    .run();

但现在有一个问题:输入_example _tensor应该是一个包含字符串/字节[]s的张量,但Java还不支持它(请参阅:tensor.Java#88“throw new UnsupportedOperationException”)。据我所知,它需要字符串的原因是build_parsing_serving_input_fn想要解析序列化的示例协议缓冲区

也许换一种服务会更好input_fn_utils.build_default_serving_input_fn看起来很有希望,但我没能成功

如果我这样称呼它:

features_dict = {'a':feature_a, 'b':feature_b}
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)

我得到了“AttributeError:'_SparseColumnIntegrated'对象没有属性'get_shape'”

如果我这样称呼它:

features = {'a': tf.constant([[1],[2]]),
            'b': tf.constant([[3],[4]]) }
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)

我得到“ValueError:'Const:0'不是有效的作用域名称”

使用input_fn_utils.build_default_serving_input_fn的正确方法是什么?我找不到任何使用它的例子


共 (0) 个答案