tf.s的用途是什么

2024-04-23 20:02:55 发布

您现在位置:Python中文网/ 问答频道 /正文

已编辑w.r.t.@quirk的答案)

我在网上读到一些tensorflow代码,看到了下面的语句:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

来源:https://github.com/Raverss/tensorflow-RLSA-NMS/blob/master/source.py#L31

positive是一个只有1的张量,negative也是一个和0大小相同的张量,输入是一些相同大小的热图(/张量)(都是tf.float32类型)。

对于我来说,代码片段似乎相当高级,可以假定,如果没有特定的原因使用tf.select(...)表达式,作者将只使用tf.cast(input > RLSA_THRESHOLD, tf.float32)。特别是因为这将消除对变量positivenegative的需要,并且将节省内存,因为它们只是存储01的昂贵冗余方式。

前面提到的tf.select(...)表达式是否等同于tf.cast(input > RLSA_THRESHOLD, tf.float32)?如果没有,为什么不呢?

注意:我通常使用Keras,如果我在这里碰到一些非常琐碎的东西,我很抱歉。


Tags: 答案代码编辑inputthreshold表达式tftensorflow
2条回答

嗯,RTD(阅读文档)!

tf.select根据condition张量中元素的boolnesspositivenegative张量中选择元素。

tf.select(condition, t, e, name=None)
Selects elements from t or e, depending on condition.
The t, and e tensors must all have the same shape, and the output will also have that shape.

(摘自官方文件。)

所以在你的情况下:

threshold = tf.select(input > RLSA_THRESHOLD, positive, negative)

input > RLSA_THRESHOLD将是bool或逻辑值(01)的张量,这将有助于从positive向量或negative向量中选择值。

例如,假设你有一个0.5的RLSA_THRESHOLD,而你的input向量是一个从0到1的实数连续值的四维向量。你的positivenegative向量本质上分别是[1, 1, 1, 1][0, 0, 0, 0]input[0.8, 0.2, 0.5, 0.6]

threshold将是[1, 0, 0, 1]

注:positivenegative可以是任何类型的张量,只要维数与condition张量一致。如果positivenegative分别是[2, 4, 6, 8][1, 3, 5, 7],那么threshold就是[2, 3, 5, 8]


The code snippet seems reasonably advanced for me to assume that the authors would have just used input > RLSA_THRESHOLD if there was no specific reason for the tf.select.

有一个很好的理由。input > RLSA_THRESHOLD只需返回逻辑(布尔)值的张量。逻辑值不能与数值很好地混合。不能将它们用于任何实际的数值计算。如果positive和/或negative张量是实数,那么您可能需要threshold张量也有实数,以防您打算进一步使用它们。


Is the tf.select equivalent to input > RLSA_THRESHOLD? If not, why not?

不,他们不是。一个是函数,另一个是张量。

我要给你一个怀疑的好处,假设你想问:

Is the threshold equivalent to input > RLSA_THRESHOLD? If not, why not?

不,他们不是。如上所述,input > RLSA_THRESHOLD是数据类型为bool的逻辑张量。threshold,另一方面,是与positivenegative具有相同数据类型的张量。

注意:您始终可以使用中可用的任何casting方法将逻辑张量转换为数值张量(或任何其他受支持的数据类型)。

最好的理解方法是亲自尝试:

In [86]: s = tf.InteractiveSession()

In [87]: inputs = tf.random_uniform([10], 0., 1.)

In [88]: positives = tf.ones([10])

In [89]: negatives = tf.zeros([10])    

In [90]: s.run([inputs, tf.select(inputs > .5, positives, negatives)])
Out[90]: 
[array([ 0.13187623,  0.77344072,  0.29853749,  0.29245567,  0.53489852,
         0.34861541,  0.15090156,  0.40595055,  0.34910154,  0.24349082], dtype=float32),
 array([ 0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.], dtype=float32)]

对于张量中的每个值>;0.5,您将在同一索引处得到一个1.,否则该值为0.

inputs > .5的结果是布尔张量(满足条件的值为True,否则为False)。

In [92]: s.run(inputs > .5)
Out[92]: array([ True, False,  True,  True,  True,  True,  True,  True, False,  True], dtype=bool)

相关问题 更多 >