在tensorflow中,如何对top_k进行操作并创建一个与原始张量和修改后的top_k的新张量
我一直在尝试做一件很简单的事情,但一直没有成功。我有一个张量(比如说 X
,它的形状是 (None, 128)
),里面包含了一些分数,也就是说每一批数据有128个分数。现在我用 Y = tf.math.top_k(X, k=a).indices
来找出前 a
个分数。为了简单起见,我们假设 a = 95
。这样,张量 Y
的形状就会变成 (None, 95)
。到这里为止,一切都还好。
现在我的原始 data
张量的形状是 (None, 3969, 128)
。我想对那些拥有前 k
分数的数据进行一些操作。所以我用以下代码提取了这些数据:
ti = tf.reshape(Y, [Y.shape[-1], -1]) # Here ti is of shape (95, None)
fs = tf.gather(X, ti[:, 0], axis=-1) # Here fs is of shape (None, 3969, 95)
然后我进行了操作,比如 Z = fs * 0.7 # 这里 Z 的形状是 (None, 3969, 95)
。这也没问题。
现在我想创建一个新的张量 F
,这个张量首先要保持形状 (None, 3969, 128)
,里面包含所有未改变的数据(那些分数不在前 k
的数据)和已修改的数据(那些分数在前 k
的数据,并且在 Z
中被修改过)。但是,这些数据的顺序要和原始数据保持一致,也就是说,修改过的数据应该仍然在它们原来的位置上。我在这里卡住了。
我对 TensorFlow 还比较陌生,所以如果我漏掉了什么简单的东西或者表达不清楚,请见谅。我已经在这上面卡了好几天了。
谢谢!
1 个回答
1
一种方法是使用 tf.tensor_scatter_nd_update
。不过,你需要把从topk函数得到的索引转换成适合你数据的格式。为此,你可以结合使用tf.tile和tf.unravel_index,把X
的形状转换成你的data
的形状。
假设你的数据是三维的,下面的代码可以给你一个类似的思路:
# getting the dimensions of the data tensor
# assuming that the shape of X is (B,N) and the shape of data is (B,D,N)
B, D, N = tf.unstack(tf.shape(data))
topk = tf.math.top_k(X, k=k)
# to get the absolute indices in the original tensor, we need to:
# - tile to get indices from (B,N) to (B,D,N)
# - do index arithmetics to get indices on the flattened tensor
topk_idx_tiled = tf.tile(topk.indices[:,None,:], [1,D,1])
flattened_indices = tf.reshape(tf.reshape(tf.range(B*D)*N,(B,D))[...,None] + topk_idx_tiled, -1)
# unraveling to get indices with batch dimensions so that we have compatibility with scatter_nd
sc_idx = tf.transpose(tf.unravel_index(flattened_indices, tf.shape(data)))
# scattering the updates to update the original data
updates = tf.reshape(tf.tile(topk.values[:,None,:],[1,D,1]),-1)*0.7
F = tf.tensor_scatter_nd_update(data, sc_idx, updates)