让我们考虑一个玩具数据集ordered,它有两个特性:
value
(例如1, 2, 3, 4, 5, 111, 222, 333, 444, 555
)sequence_id
(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1
)该数据基本上由两个串联的扁平序列组成,1, 2, 3, 4, 5
(序列0
)和{1
)。在
我想生成大小为t
(比如3
)的序列,由来自同一序列(sequence_id
)的连续元素组成,我不希望序列具有属于不同的sequence_id
的元素。在
例如,在不进行任何洗牌的情况下,我希望获得以下批次:
1, 2, 3
2, 3, 4
3, 4, 5
111, 222, 333
222, 333, 444
333, 444, 555
1, 2, 3
我知道如何使用tf.data.Dataset.window
或tf.data.Dataset.batch
生成序列数据,但我不知道如何防止一个序列包含不同的sequence_id
(例如,序列{0
和序列{
以下是我失败的尝试:
import tensorflow as tf
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
.window(3, 1, drop_remainder=True)\
.repeat(-1)\
.flat_map(lambda x, y: x.batch(3))\
.batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
print(sess.run(next_element))
结果是:
^{pr2}$
您可以使用
filter()
来判断sequence_id
是否一致。因为filter()
转换当前不支持嵌套数据集作为输入,所以您需要zip()
。在相关问题 更多 >
编程相关推荐