形状为b*n*3的张量T1。形状为b*n的T2>一个布尔张量,指示在T1中采用哪一行

2024-05-13 17:45:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在n点三维坐标和b批中提取特定的行。本质上,我的张量T1的形状b*n*3。我还有另一个布尔张量T2的形状b * n,它指示了n的哪些行需要被接受。 本质上,我的输出应该是b*?*3,因为T2在每个批中可以有不同数量的1

我已经使用布尔掩码实现了以下内容,但是输出不是预期的,输出形状是(?,),而不是(b*?*3)

# expand T2 to (b,n,3). i.e. 0 replicates to (0,0,0) and so is 1

mask = tf.tile(tf.expand_dims(T2,2), [1,1,3])

# query using boolean mask where there are 1s

valid_KPs = tf.boolean_mask(T1, tf.cast(mask, tf.int32))

Tags: andto数量soistfmaskexpand
1条回答
网友
1楼 · 发布于 2024-05-13 17:45:52

由于每个例子中所选元素的数量可能不同,因此不能用适当的张量来表示。一种方法是使用ragged tensor。它们不能做正常张量所能做的一切,但可以实现你想要的,例如:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # Input data
    t1 = tf.constant([
        [
            [ 1,  2,  3],
            [ 4,  5,  6],
            [ 7,  8,  9],
        ],
        [
            [10, 11, 12],
            [13, 14, 15],
            [16, 17, 18],
        ],
    ])
    t2 = tf.constant([
        [1, 0, 1],
        [0, 1, 0],
    ])
    # Count the number of ones for each row in T2
    c = tf.reduce_sum(t2, axis=1)
    # Ragged ranges for each row
    r = tf.ragged.range(c)
    # Sorting indices so indices with a one are first
    s = tf.argsort(t2, axis=1, direction='DESCENDING', stable=True)
    # First axis dimension index
    idx0 = tf.expand_dims(tf.range(tf.shape(t1)[0]), 1) * tf.ones_like(r)
    # 2D index for getting indices of ones on each row
    idx_s = tf.stack([idx0, r], axis=-1)
    # Get indices of ones
    idx1 = tf.gather_nd(s, idx_s)
    # 2D index to get indices of selected vectors in T1
    idx = tf.stack([idx0, idx1], axis=-1)
    # Get selected vectors
    result = tf.gather_nd(t1, idx)
    # Print result
    print(sess.run(result))
    # <tf.RaggedTensorValue [[[1, 2, 3], [7, 8, 9]], [[13, 14, 15]]]>

相关问题 更多 >