如何在Tensorflow张量中的每一行和每一列中选取顶部的q元素?

2024-04-24 15:15:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我在张量流中实现了一个特殊的损失函数。下面是一个特殊函数的numpy样式的代码,该函数选择最上面的q元素并屏蔽每行和每列中的其他元素。注意,An*n矩阵,q是小于n的整数。你知道吗

def thresh(A, q):
    A_ = A.copy()
    n = A_.shape[1]
    for i in range(n):
        A_[i, :][A_[i, :].argsort()[0:n - q]] = 0
        A_[:, i][A_[:, i].argsort()[0:n - q]] = 0
    return A_

现在的问题是,我有一个张量流张量A,它的形状是(n,n),我想实现与numpy相同的逻辑。但是,我不能使用索引直接为张量A赋值。有什么解决办法吗?你知道吗


Tags: 函数代码numpy元素fordef矩阵样式
1条回答
网友
1楼 · 发布于 2024-04-24 15:15:41

TLDR;

我们可以创建一个函数来屏蔽除顶层k元素以外的所有元素,如下所示:

def mask_all_but_top_k(X, k):
  n = X.shape[1]
  top_k_indices = tf.math.top_k(X, k).indices
  mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
  return mask * X

不幸的是tf.map.top_k不允许我们指定维度,但是我们当然可以通过首先转置X然后转置结果来按列复制这个维度

说明

我们可以通过创建一个由1和0组成的掩码,然后按元素相乘来实现。你知道吗

例如,考虑到n=4, k=2的情况,我们有以下矩阵:

array([[0.67757607, 0.74070597, 0.89508283, 0.11858773],
       [0.7661159 , 0.8737055 , 0.73599136, 0.1552105 ],
       [0.7093129 , 0.44203556, 0.48861897, 0.83231044],
       [0.24682868, 0.36648738, 0.92984104, 0.9881872 ]], dtype=float32)

然后我们可以使用tf.math.top_k函数来获得矩阵每行中前2个值的索引:

top_k_indices = tf.math.top_k(X, 2).indices

现在,我们使用一个小技巧首先one_hot编码这些:

tf.one_hot(top_k_indices, 4)
array([[[0., 0., 1., 0.],
        [0., 1., 0., 0.]],

       [[0., 1., 0., 0.],
        [1., 0., 0., 0.]],

       [[0., 0., 0., 1.],
        [1., 0., 0., 0.]],

       [[0., 0., 0., 1.],
        [0., 0., 1., 0.]]], dtype=float32)>

然后reduce_sum跨越第二到最后一个维度来创建我们的掩码:

tf.reduce_sum(tf.one_hot(top_k_indices, 4), axis=1)
array([[0., 1., 1., 0.],
       [1., 1., 0., 0.],
       [1., 0., 0., 1.],
       [0., 0., 1., 1.]], dtype=float32)>

现在我们只需进行Hadamard(元素)乘法即可得到所需的结果:

array([[0.        , 0.74070597, 0.89508283, 0.        ],
       [0.7661159 , 0.8737055 , 0.        , 0.        ],
       [0.7093129 , 0.        , 0.        , 0.83231044],
       [0.        , 0.        , 0.92984104, 0.9881872 ]], dtype=float32)>

把这些放在一起,我们可以创建一个函数,它按行屏蔽除顶部k元素以外的所有元素,如下所示:

def mask_all_but_top_k(X, k):
  n = X.shape[1]
  top_k_indices = tf.math.top_k(X, k).indices
  mask = tf.reduce_sum(tf.one_hot(top_k_indices, n), axis=1)
  return mask * X

相关问题 更多 >