在tensorflow中,如何对top_k进行操作并创建一个与原始张量和修改后的top_k的新张量

0 投票
1 回答
32 浏览
提问于 2025-04-14 15:50

我一直在尝试做一件很简单的事情,但一直没有成功。我有一个张量(比如说 X,它的形状是 (None, 128)),里面包含了一些分数,也就是说每一批数据有128个分数。现在我用 Y = tf.math.top_k(X, k=a).indices 来找出前 a 个分数。为了简单起见,我们假设 a = 95。这样,张量 Y 的形状就会变成 (None, 95)。到这里为止,一切都还好。

现在我的原始 data 张量的形状是 (None, 3969, 128)。我想对那些拥有前 k 分数的数据进行一些操作。所以我用以下代码提取了这些数据:

ti = tf.reshape(Y, [Y.shape[-1], -1])   # Here ti is of shape (95, None)
fs = tf.gather(X, ti[:, 0], axis=-1)    # Here fs is of shape (None, 3969, 95)

然后我进行了操作,比如 Z = fs * 0.7 # 这里 Z 的形状是 (None, 3969, 95)。这也没问题。

现在我想创建一个新的张量 F,这个张量首先要保持形状 (None, 3969, 128),里面包含所有未改变的数据(那些分数不在前 k 的数据)和已修改的数据(那些分数在前 k 的数据,并且在 Z 中被修改过)。但是,这些数据的顺序要和原始数据保持一致,也就是说,修改过的数据应该仍然在它们原来的位置上。我在这里卡住了

我对 TensorFlow 还比较陌生,所以如果我漏掉了什么简单的东西或者表达不清楚,请见谅。我已经在这上面卡了好几天了。

谢谢!

1 个回答

1

一种方法是使用 tf.tensor_scatter_nd_update。不过,你需要把从topk函数得到的索引转换成适合你数据的格式。为此,你可以结合使用tf.tile和tf.unravel_index,把X的形状转换成你的data的形状。

假设你的数据是三维的,下面的代码可以给你一个类似的思路:

# getting the dimensions of the data tensor
# assuming that the shape of X is (B,N) and the shape of data is (B,D,N)
B, D, N = tf.unstack(tf.shape(data))
topk = tf.math.top_k(X, k=k)
# to get the absolute indices in the original tensor, we need to:
# - tile to get indices from (B,N) to (B,D,N)
# - do index arithmetics to get indices on the flattened tensor
topk_idx_tiled = tf.tile(topk.indices[:,None,:], [1,D,1])
flattened_indices = tf.reshape(tf.reshape(tf.range(B*D)*N,(B,D))[...,None] + topk_idx_tiled, -1)
# unraveling to get indices with batch dimensions so that we have compatibility with scatter_nd
sc_idx = tf.transpose(tf.unravel_index(flattened_indices, tf.shape(data)))
# scattering the updates to update the original data 
updates = tf.reshape(tf.tile(topk.values[:,None,:],[1,D,1]),-1)*0.7
F = tf.tensor_scatter_nd_update(data, sc_idx, updates)

撰写回答