如何使用地图迭代一个张量并在每次迭代中返回不同维度的值?

2024-05-13 12:11:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我想迭代一个张量,而不是使用我必须使用的急切执行地图(). 我想做的事情如下:

import tensorflow as tf

list_of_values = tf.constant([[1, 9, 65, 43], [8, 23, 21, 48], [11, 14, 98, 21], [98, 12, 32, 12]])

def value_finder(i):
  def f1():
    # Some computation with a local variable 'a' occurs
    # . . .
    return a  # a = [[3, 4, 5]]

  def f2():
    # Some computation with a local variable 'b' occurs
    # . . .
    return b  # b = [[7, 1, 2], [9, 3, 11]]

  return tf.cond(tf.reduce_all(tf.less(tf.slice(i, [1], [1]), tf.constant(18))), f1, f2)

value_obtained = tf.map_fn(lambda i: value_finder(i), list_of_values))

值a和b的维数不同,因此每当我试图运行我的代码时都会出错。在我的例子中,返回的值不可避免地具有不均匀的维度。有没有其他方法可以迭代张量并得到结果,而不是通过填充值使它们具有相等的维数?在


Tags: ofreturnfindervaluelocaltfdefwith