我在张量流中实现了一个特殊的损失函数。下面是一个特殊函数的numpy样式的代码,该函数选择最上面的q元素并屏蔽每行和每列中的其他元素。注意,A
是n*n
矩阵,q
是小于n
的整数。你知道吗
def thresh(A, q):
A_ = A.copy()
n = A_.shape[1]
for i in range(n):
A_[i, :][A_[i, :].argsort()[0:n - q]] = 0
A_[:, i][A_[:, i].argsort()[0:n - q]] = 0
return A_
现在的问题是,我有一个张量流张量A
,它的形状是(n,n)
,我想实现与numpy相同的逻辑。但是,我不能使用索引直接为张量A
赋值。有什么解决办法吗?你知道吗
TLDR;
我们可以创建一个函数来屏蔽除顶层
k
元素以外的所有元素,如下所示:不幸的是
tf.map.top_k
不允许我们指定维度,但是我们当然可以通过首先转置X
然后转置结果来按列复制这个维度说明
我们可以通过创建一个由1和0组成的掩码,然后按元素相乘来实现。你知道吗
例如,考虑到
n=4, k=2
的情况,我们有以下矩阵:然后我们可以使用
tf.math.top_k
函数来获得矩阵每行中前2个值的索引:现在,我们使用一个小技巧首先
one_hot
编码这些:然后
reduce_sum
跨越第二到最后一个维度来创建我们的掩码:现在我们只需进行Hadamard(元素)乘法即可得到所需的结果:
把这些放在一起,我们可以创建一个函数,它按行屏蔽除顶部
k
元素以外的所有元素,如下所示:相关问题 更多 >
编程相关推荐