张量流稀疏传感器的有效布尔掩蔽

2024-04-27 03:54:13 发布

您现在位置:Python中文网/ 问答频道 /正文

所以,我想屏蔽SparseTensor的整行。用tf.boolean_mask很容易做到这一点,但是SparseTensor没有等价的方法。目前,我有可能只检查SparseTensor.indices中的所有索引,然后过滤掉所有不是屏蔽行的索引,例如:

masked_indices = list(filter(lambda index: masked_rows[index[0]], indices))

其中masked\u rows是一个1D数组,表示该索引处的行是否被屏蔽。你知道吗

然而,这真的很慢,因为我的SparseTensor相当大(它有90k个索引,但会越来越大)。在我对过滤的索引应用SparseTensor.mask之前,在单个数据点上花费了相当多的时间。这种方法的另一个缺陷是,它实际上也没有删除行(尽管在我的例子中,一个全零的行也一样好)。你知道吗

有没有更好的方法来屏蔽一个行稀疏传感器,还是这是最好的方法?你知道吗


Tags: 方法lambdaindextfmask数组filterlist
1条回答
网友
1楼 · 发布于 2024-04-27 03:54:13

你可以这样做:

import tensorflow as tf

def boolean_mask_sparse_1d(sparse_tensor, mask, axis=0):  # mask is assumed to be 1D
    mask = tf.convert_to_tensor(mask)
    ind = sparse_tensor.indices[:, axis]
    mask_sp = tf.gather(mask, ind)
    new_size = tf.math.count_nonzero(mask)
    new_shape = tf.concat([sparse_tensor.shape[:axis], [new_size],
                           sparse_tensor.shape[axis + 1:]], axis=0)
    new_shape = tf.dtypes.cast(new_shape, tf.int64)
    mask_count = tf.cumsum(tf.dtypes.cast(mask, tf.int64), exclusive=True)
    masked_idx = tf.boolean_mask(sparse_tensor.indices, mask_sp)
    new_idx_axis = tf.gather(mask_count, masked_idx[:, axis])
    new_idx = tf.concat([masked_idx[:, :axis],
                         tf.expand_dims(new_idx_axis, 1),
                         masked_idx[:, axis + 1:]], axis=1)
    new_values = tf.boolean_mask(sparse_tensor.values, mask_sp)
    return tf.SparseTensor(new_idx, new_values, new_shape)

# Test
sp = tf.SparseTensor([[1], [3], [4], [6]], [1, 2, 3, 4], [7])
mask = tf.constant([True, False, True, True, False, False, True])
out = boolean_mask_sparse_1d(sp, mask)
print(out.indices.numpy())
# [[2]
#  [3]]
print(out.values.numpy())
# [2 4]
print(out.shape)
# (4,)

相关问题 更多 >