tensorflow map_fn TensorArray 不一致的形状

2024-04-29 03:20:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我在玩map_fn函数,注意到它输出了一个张量阵列,这意味着它能够输出“锯齿状”张量(其中内部的张量具有不同的第一维度)。在

我试着用这段代码来证明这一点:

import tensorflow as tf
import numpy as np

NUM_ARRAYS = 1000
MAX_LENGTH = 1000

lengths = tf.placeholder(tf.int32)
tArray = tf.map_fn(lambda x: tf.random_normal((x,), 0, 1),
                   lengths,
                   dtype=tf.float32) # Should return a TensorArray.

# startTensor =  tf.random_normal((tf.reduce_sum(lengths),), 0, 1)
# tArray = tf.TensorArray(tf.float32, NUM_ARRAYS)
# tArray = tArray.split(startTensor, lengths)
# outArray = tArray.concat()


with tf.Session() as sess:
    outputArray, l = sess.run(
        [tArray, lengths],
        feed_dict={lengths: np.random.randint(MAX_LENGTH, size=NUM_ARRAYS)})
    print outputArray.shape, l

但是,得到的错误是:

张量阵列的形状不一致。索引0的形状为:[259],但索引1的形状为:[773]”

这当然让我感到惊讶,因为我的印象是张量阵列应该能够处理它。我错了?在


Tags: importmaptfasnprandomlengthnum
1条回答
网友
1楼 · 发布于 2024-04-29 03:20:23

虽然^{}确实在内部使用^{}对象,并且tf.TensorArray可以容纳不同大小的对象,但是这个程序不会正常工作,因为tf.map_fn()通过将元素堆叠在一起,将其tf.TensorArray结果转换回tf.Tensor,而这个操作失败了。在

但是,您可以使用较低的^{}操作来实现基于tf.TensorArray的操作:

lengths = tf.placeholder(tf.int32)
num_elems = tf.shape(lengths)[0]
init_array = tf.TensorArray(tf.float32, size=num_elems)

def loop_body(i, ta):
  return i + 1, ta.write(i, tf.random_normal((lengths[i],), 0, 1))

_, result_array = tf.while_loop(
    lambda i, ta: i < num_elems, loop_body, [0, init_array])

相关问题 更多 >