在Java中使用经过训练的TensorFlow模型
我已经用Python训练了一个TensorFlow模型,并希望在Java代码中使用它。通过以下代码对模型进行培训:
def input_fn():
features = {'a': tf.constant([[1],[2]]),
'b': tf.constant([[3],[4]]) }
labels = tf.constant([0, 1])
return features, labels
feature_a = tf.contrib.layers.sparse_column_with_integerized_feature("a", bucket_size=10)
feature_b = tf.contrib.layers.sparse_column_with_integerized_feature("b", bucket_size=10)
feature_columns = [feature_a, feature_b]
model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns)
model.fit(input_fn=input_fn, steps=10)
现在我想保存这个模型,以便在Java中使用它。似乎export_savedmodel是新的/首选的储蓄方式,所以我尝试了:
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
model.export_savedmodel('export', serving_input_fn, as_text=True)
这将生成一个保存的模型,可以使用
model = SavedModelBundle.load(dir, "serve");
model.session().runner()
.feed("input_example_tensor", input)
.fetch("linear/binary_logistic_head/predictions/probabilities")
.run();
但现在有一个问题:输入_example _tensor应该是一个包含字符串/字节[]s的张量,但Java还不支持它(请参阅:tensor.Java#88“throw new UnsupportedOperationException”)。据我所知,它需要字符串的原因是build_parsing_serving_input_fn想要解析序列化的示例协议缓冲区
也许换一种服务会更好input_fn_utils.build_default_serving_input_fn
看起来很有希望,但我没能成功
如果我这样称呼它:
features_dict = {'a':feature_a, 'b':feature_b}
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)
我得到了“AttributeError:'_SparseColumnIntegrated'对象没有属性'get_shape'”
如果我这样称呼它:
features = {'a': tf.constant([[1],[2]]),
'b': tf.constant([[3],[4]]) }
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)
我得到“ValueError:'Const:0'不是有效的作用域名称”
使用input_fn_utils.build_default_serving_input_fn
的正确方法是什么?我找不到任何使用它的例子
共 (0) 个答案