如何在tensorflow中自动合并形状?

2024-06-17 11:12:42 发布

您现在位置:Python中文网/ 问答频道 /正文

对于高阶张量,我不知道如何自动操纵它的形状。在

例如:

                                #   0  1  2  3   -1
a.shape                         # [?, ?, ?, ?, ..., ?]
merge_dims(a, [0]   ).shape     # [?* ?, ?, ?, ..., ?]
merge_dims(a, [1, 2]).shape     # [?, ?* ?* ?, ..., ?]
                                #   ^  ^  ^  ^    ^

使用merge_dims,由位置号标记的逗号应该变成倍数,从而生成一个低阶张量。在

谢谢:)


Tags: 标记merge形状逗号shape低阶倍数dims
1条回答
网友
1楼 · 发布于 2024-06-17 11:12:42

这是一个函数,可以执行以下操作:

import tensorflow as tf

def merge_dims(x, axis, num=1):
    # x: input tensor
    # axis: first dimension to merge
    # num: number of merges
    shape = tf.shape(x)
    new_shape = tf.concat([
        shape[:axis],
        [tf.reduce_prod(shape[axis:axis + num + 1])],
        shape[axis + num + 1:]], axis=0)
    return tf.reshape(x, new_shape)

with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.ones([2, 4, 6, 8, 10])
    print(sess.run(tf.shape(merge_dims(a, 0))))
    # [ 8  6  8 10]
    print(sess.run(tf.shape(merge_dims(a, 1, num=2))))
    # [  2 192  10]

相关问题 更多 >