我有一个data
维张量[B X N X 3]
,我有一个indices
维张量[B X M]
。我想用indices
张量从data
张量中提取一个[B X M X 3]
张量。在
我有一个有效的代码:
new_data= []
for i in range(B):
new_data.append(tf.gather(data[i], indices[i]))
new_data= tf.stack(new_data)
然而,我确信这不是正确的方法。有人知道更好的方法吗?(我想我应该以某种方式使用tf.gather_nd()
,但我不知道怎么用)
我看到了一些类似问题的答案here。但是我找不到解决问题的办法。在
您可以将
tf.gather_nd()
与以下代码一起使用:
^{pr2}$new_data
将是:相关问题 更多 >
编程相关推荐