用tf.d批处理序列数据

2024-04-25 09:24:45 发布

您现在位置:Python中文网/ 问答频道 /正文

让我们考虑一个玩具数据集ordered,它有两个特性:

  • value(例如1, 2, 3, 4, 5, 111, 222, 333, 444, 555
  • sequence_id(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1

该数据基本上由两个串联的扁平序列组成,1, 2, 3, 4, 5(序列0)和{}(序列1)。在

我想生成大小为t(比如3)的序列,由来自同一序列(sequence_id)的连续元素组成,我不希望序列具有属于不同的sequence_id的元素。在

例如,在不进行任何洗牌的情况下,我希望获得以下批次:

  • 第一批:1, 2, 3
  • 第二批:2, 3, 4
  • 第三批:3, 4, 5
  • 第四批:111, 222, 333
  • 第五批:222, 333, 444
  • 第6批:333, 444, 555
  • 第7批:1, 2, 3
  • 等等

我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列数据,但我不知道如何防止一个序列包含不同的sequence_id(例如,序列{}应该无效,因为它混合了来自序列0和序列{}的元素)。在

以下是我失败的尝试:

import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555], 
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
                .window(3, 1, drop_remainder=True)\
                .repeat(-1)\
                .flat_map(lambda x, y: x.batch(3))\
                .batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

结果是:

^{pr2}$

Tags: 数据id元素datatfasbatchit
1条回答
网友
1楼 · 发布于 2024-04-25 09:24:45

您可以使用filter()来判断sequence_id是否一致。因为filter()转换当前不支持嵌套数据集作为输入,所以您需要zip()。在

import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
                .window(3, 1, drop_remainder=True) \
                .flat_map(lambda x, y: tf.data.Dataset.zip((x,y)).batch(3))\
                .filter(lambda x,y: tf.equal(tf.size(tf.unique(y)[0]),1))\
                .map(lambda x,y:x)\
                .repeat(-1)\
                .batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

[[  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]
 [222 333 444]
 [333 444 555]
 [  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]]

相关问题 更多 >