如何正确导出TensorFlow WALS矩阵分解估计器的\u savedmodel()?

2024-05-16 04:27:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在张量流上使用thisWALS矩阵分解模块。拟合估计器后,我尝试使用export\u savedmodel()方法保存模型,但无法提供正确的参数。代码如下:

from tensorflow.contrib.factorization.python.ops import wals as wals_lib

# dense input array that shows the user X item interactions
# here set as dummy array
dense_array = np.ones((10,10))
num_rows, num_cols = dense_array.shape
emebedding_dim = 5 # manually setting hidden factor

factorizer = wals_lib.WALSMatrixFactorization(num_rows, num_cols, embedding_dim, max_sweeps=10)

# this generate_input_fn() is not shown here but it's a copy of
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/factorization/python/ops/wals_test.py#L82
input_fn, _, _ = generate_input_fn(np_matrix=dense_array, batch_size=32, mode=model_fn.ModeKeys.TRAIN)
factorizer.fit(input_fn, steps=10)

# MY PROBLEM IS HERE
# How to define the correct serving input function?
factorizer.export_savedmodel('path/to/save/model', serving_input_fn=???)

这里比较棘手的部分是,我相信WALS模块使用的是一个旧的TensorFlow范例,serving_input_fn参数需要一个返回InputFnOps的可调用函数。然而,更新得越多的估计器,如thisone,则需要返回tf.estimator.export.ServingInputReceivertf.estimator.export.TensorServingInputReceiver的函数。我承认我还不能完全流利地使用TensorFlow的输入函数,但是对于保存我的WALS估计器的具体用例的任何帮助都将不胜感激。谢谢!你知道吗


Tags: 模块input参数tensorflowexportcontribarraynum