理解CTC如何实现TF工作

2024-06-17 15:01:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图理解CTC实现在TensorFlow中是如何工作的。我已经写了一个简单的例子来测试CTC函数,但出于某种原因,我正在为一些目标/输入值指定inf,我确定为什么会发生这种情况!?在

代码:

import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representention of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels+1)),dtype=tf.float32) # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major = False))


with tf.Session() as sess:
    for it in range(10):
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])

        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})
        print('******* Iter: %d *******'%it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()

样本输出:

^{pr2}$

我有什么想法吗!?在


Tags: targetlentfnplengthseqsparsevalues
1条回答
网友
1楼 · 发布于 2024-06-17 15:01:45

仔细看看你的输入文本(rand_target),我相信你会看到一些与inf丢失值相关的简单模式;-)

对正在发生的事情的简短解释: CTC通过允许每个字符重复对文本进行编码,还允许在字符之间插入非字符标记(称为“CTC空白标签”)。撤销这种编码(或解码)就意味着扔掉重复的字符,然后扔掉所有的空格。 给出一些例子(“…”)对应于文本、“……”来对空白标签进行编码和“-'”:

  • “to”->;“tttooo”或“t-o”或“t-oo”或“to”等。。。在
  • “too”->;“to-o”,或“tttoo-oo”,或“-t-o-o”,但不是“too”(想想解码后的“too”是什么样子)

现在我们知道了为什么你们的一些样品失败了:

  • 输入文本的长度为2
  • 编码长度为2
  • 如果输入字符是重复的(例如'11',或者是python列表:[1,1]),那么编码的唯一方法就是在中间加上一个空格(想想看丰富的解码'11'和'1-1')。但是编码的长度是3。在
  • 因此,无法将长度为2且具有重复字符的文本编码为长度2编码,因此TF loss实现返回inf

您也可以将编码想象为一个状态机-参见下面的插图。文本“11”可以用从开始状态(两个最左边的状态)开始到最终状态(两个最右边的状态)结束的所有可能路径来表示。如你所见,最短的可能路径是'1-1'。在

enter image description here

最后,您必须为输入文本中的每个重复字符至少添加一个空白。 也许本文有助于理解CTC:https://towardsdatascience.com/3797e43a86c

相关问题 更多 >