我试图在n
点三维坐标和b
批中提取特定的行。本质上,我的张量T1
的形状b*n*3
。我还有另一个布尔张量T2
的形状b * n
,它指示了n
的哪些行需要被接受。
本质上,我的输出应该是b*?*3
,因为T2
在每个批中可以有不同数量的1
我已经使用布尔掩码实现了以下内容,但是输出不是预期的,输出形状是(?,)
,而不是(b*?*3)
# expand T2 to (b,n,3). i.e. 0 replicates to (0,0,0) and so is 1
mask = tf.tile(tf.expand_dims(T2,2), [1,1,3])
# query using boolean mask where there are 1s
valid_KPs = tf.boolean_mask(T1, tf.cast(mask, tf.int32))
由于每个例子中所选元素的数量可能不同,因此不能用适当的张量来表示。一种方法是使用ragged tensor。它们不能做正常张量所能做的一切,但可以实现你想要的,例如:
相关问题 更多 >
编程相关推荐