在TensorF中使用预先训练的单词嵌入(word2vec或Glove)

2024-04-28 15:07:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我最近回顾了convolutional text classification的一个有趣的实现。然而,我回顾过的所有TensorFlow代码都使用随机(非预训练)嵌入向量,如下所示:

with tf.device('/cpu:0'), tf.name_scope("embedding"):
    W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

有没有人知道如何使用Word2vec的结果,或者一个手套预先训练的单词嵌入,而不是一个随机的?


Tags: 代码textnameselfsizedevicetftensorflow
3条回答

我使用此方法加载和共享嵌入。

W = tf.get_variable(name="W", shape=embedding.shape, initializer=tf.constant_initializer(embedding), trainable=False)

有几种方法可以在TensorFlow中使用预先训练的嵌入。假设您在一个名为embedding的NumPy数组中嵌入了vocab_size行和embedding_dim列,并且您希望创建一个可以在调用^{}时使用的张量W

  1. 只需将W创建为以embedding为其值的^{}

    W = tf.constant(embedding, name="W")
    

    这是最简单的方法,但它不具有内存效率,因为tf.constant()的值多次存储在内存中。因为embedding可能非常大,所以应该只对玩具示例使用此方法。

  2. 创建W作为tf.Variable,并通过^{}从NumPy数组初始化它:

    W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]),
                    trainable=False, name="W")
    
    embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
    embedding_init = W.assign(embedding_placeholder)
    
    # ...
    sess = tf.Session()
    
    sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})
    

    这样可以避免在图中存储embedding的副本,但它确实需要足够的内存来同时在内存中保留矩阵的两个副本(一个用于NumPy数组,另一个用于tf.Variable)。请注意,我假设您希望在训练期间保持嵌入矩阵常量,所以W是用trainable=False创建的。

  3. 如果嵌入被训练为另一个TensorFlow模型的一部分,则可以使用^{}从另一个模型的检查点文件加载值。这意味着嵌入矩阵可以完全绕过Python。创建选项2中的W,然后执行以下操作:

    W = tf.Variable(...)
    
    embedding_saver = tf.train.Saver({"name_of_variable_in_other_model": W})
    
    # ...
    sess = tf.Session()
    embedding_saver.restore(sess, "checkpoint_filename.ckpt")
    

@mrry的答案是不正确的,因为它证明了每次网络运行时都会覆盖嵌入的权重,所以如果您采用小批量方法来训练网络,那么您就是覆盖了嵌入的权重。所以,在我看来,正确的预训练嵌入方法是:

embeddings = tf.get_variable("embeddings", shape=[dim1, dim2], initializer=tf.constant_initializer(np.array(embeddings_matrix))

相关问题 更多 >