关于张量流中几个元素的指数的求法

2024-04-24 11:06:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我是Tensorflow的新手。你知道吗

我有一个问题。你知道吗

这里有1d阵列。你知道吗

 values = [101,103,105,109,107]

 target_values = [105, 103]

我想马上从values得到一个关于target_values的索引。你知道吗

下面将显示从上述示例中提取的索引。你知道吗

indices = [2, 1]

当我使用tf.map_fn函数时。 这个问题很容易解决。你知道吗

# if you do not change data type from int64 to int32. TypeError will riase
values = tf.cast(tf.constant([100, 101, 102, 103, 104]), tf.int64)
target_values = tf.cast(tf.constant([100, 101]), tf.int64)
indices = tf.map_fn(lambda x: tf.where(tf.equal(values, x)), target_values)

谢谢你!你知道吗


Tags: 函数you示例maptargetiftftensorflow
1条回答
网友
1楼 · 发布于 2024-04-24 11:06:52

假设target_values中的所有值都在values中,这是一种简单的方法(tf2.x,但函数对1.x的作用应该相同):

import tensorflow as tf

values = [101, 103, 105, 109, 107]
target_values = [105, 103]

# Assumes all values in target_values are in values
def find_in_array(values, target_values):
    values = tf.convert_to_tensor(values)
    target_values = tf.convert_to_tensor(target_values)
    # stable=True if there may be repeated elements in values
    # and you want always first occurrence
    idx_s = tf.argsort(values, stable=True)
    values_s = tf.gather(values, idx_s)
    idx_search = tf.searchsorted(values_s, target_values)
    return tf.gather(idx_s, idx_search)

print(find_in_array(values, target_values).numpy())
# [2 1]

相关问题 更多 >