用TensorFlow/if-then在TensorF中编写分段函数

2024-04-25 08:43:33 发布

您现在位置:Python中文网/ 问答频道 /正文

如何编写一个分段张量流函数,即一个内部有if语句的函数?在

现行代码

import tensorflow as tf

my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
  x = tf.Variable(tf.random_normal([100, 1]))
  output = tf.map_fn(my_fn, x)

错误:

TypeError:不允许使用tf.Tensor作为Python bool。使用if t is not None:代替if t:来测试是否定义了张量,并使用逻辑TensorFlow ops来测试张量的值。在


Tags: lambda函数代码importifmytftensorflow
3条回答

你应该看看^{}。在

例如,您可以:

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)

编辑:从tf.select移动到tf.where

tf.select也不再工作,如此线程所示 https://github.com/tensorflow/tensorflow/issues/8647

对我有用的东西是tf.where

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)

这里的问题是,my_fn无法检查条件x>0,因为x是一个tf.Tensor,这意味着只有在启动tensorflow会话并请求运行包含x的图的一部分时,它才会填充值。要在图中包含if-then逻辑,必须使用tensorflow提供的操作,例如^{}。在

相关问题 更多 >