按参与者拆分tensorflow数据集

2024-05-29 04:13:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在属性上分割一个tf.data.DataSet,在我的例子中是参与者或手势。目前,该数据集正在开发中,参与者/手势的数量可能会增加。我最初为这个gestures-dataset设置了一个tfds config,但我也没有弄清楚如何在这里配置参与者/手势拆分

我应该如何分割tf.data.DataSet对象?目前,我的数据集作为单个tf_记录存在。我更愿意保持这种方式,而不是为每个参与者和手势生成不同的文件,并且在添加新参与者时必须重新生成所有手势tf记录

本工程(方法1,总量):

subset = ds.filter(lambda x: (x['participant'] == 1 or x['participant'] == 2))

这并不是(方法2,梦想):

subset = ds.filter(lambda x: any(x['participant'] == p for p in [1,2]))

OperatorNotAllowedInGraphError: using a tf.Tensor as a Python bool is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我还尝试了与修饰@tf.function相同的操作

具有公开可用mnist数据集的示例代码:juptyer notebook on colab

他们的操作方式是否与方法2类似?


Tags: 数据方法lambdadatatf方式记录ds
1条回答
网友
1楼 · 发布于 2024-05-29 04:13:21

您可以通过以下方式解决此问题:

def predicate(label, labels_to_filter):
  return tf.math.reduce_any(tf.equal(label, labels_to_filter))


labels_to_filter = tf.constant([0, 1, 2, 3, 4], dtype=tf.int64)
subset = dataset.filter(lambda x: predicate(x["participant"], labels_to_filter))
工作原理:

tf.equal返回一个布尔张量,如果labellabels_to_filter中的一个标签匹配,则该布尔张量包含Truetf.math.reduce_any如果其输入布尔张量中有True个值,则返回True

相关问题 更多 >

    热门问题