如何使用一个张量来索引另一个以十为单位的张量

2024-03-28 18:49:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个data维张量[B X N X 3],我有一个indices维张量[B X M]。我想用indices张量从data张量中提取一个[B X M X 3]张量。在

我有一个有效的代码:

new_data= []    
for i in range(B):
        new_data.append(tf.gather(data[i], indices[i]))
new_data= tf.stack(new_data) 

然而,我确信这不是正确的方法。有人知道更好的方法吗?(我想我应该以某种方式使用tf.gather_nd(),但我不知道怎么用)

我看到了一些类似问题的答案here。但是我找不到解决问题的办法。在


Tags: 方法答案代码innewfordatastack
1条回答
网友
1楼 · 发布于 2024-03-28 18:49:31

您可以将tf.gather_nd()与以下代码一起使用:

import tensorflow as tf

# B = 3
# N = 4
# M = 2
# [B x N x 3]
data = tf.constant([
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
    [[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]],
    [[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]],
    ])

# [B x M]
indices = tf.constant([
    [0, 2],
    [1, 3],
    [3, 2],
    ])

indices_shape = tf.shape(indices)

indices_help = tf.tile(tf.reshape(tf.range(indices_shape[0]), [indices_shape[0], 1]) ,[1, indices_shape[1]]);
indices_ext = tf.concat([tf.expand_dims(indices_help, 2), tf.expand_dims(indices, 2)], axis = 2)
new_data = tf.gather_nd(data, indices_ext)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('data')
    print(sess.run(data))
    print('\nindices')
    print(sess.run(indices))
    print('\nnew_data')
    print(sess.run(new_data))

new_data将是:

^{pr2}$

相关问题 更多 >