从GCMLE保存的Mod中提取嵌入

2024-03-28 12:42:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试从一个经过训练的GCMLE预测模型本地下载嵌入,这样我就可以使用自己定制的嵌入可视化,而这些可视化在tensorboard中是不可用的。我想把这些嵌入提取到一个大块头的numpy矩阵中,但是我在几个步骤上遇到了麻烦。我可以成功下载所有文件(saved_model.pb+assets/*+variables/*),并且我似乎能够使用以下代码恢复模型:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)

成功返回:

INFO:tensorflow:Restoring parameters from Servo/variables/variables

然后我试着像这样提取权重:

^{pr2}$

它确实成功地输出了很多,但是与嵌入相关的部分只有:

u'embedding_layer/embeddings/Initializer/random_uniform/max': 0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/min': -0.012765553,
u'embedding_layer/embeddings/Initializer/random_uniform/shape': array([vocab_size, word_embedding_size], dtype=int32)

没有实际嵌入权重的迹象。如何修改上述方法以获得实际的嵌入权重矩阵?在


Tags: 模型layermodel可视化tf矩阵randomuniform
1条回答
网友
1楼 · 发布于 2024-03-28 12:42:27

这在一定程度上取决于您如何导出模型,但在大多数情况下,嵌入是变量而不是常量。所以你想要这样的东西:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], _EXPORT_DIR)

    trainable_coll = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    vars = {v.name:sess.run(v.value()) for v in trainable_coll}

相关问题 更多 >