TF2/Keras切片张量使用[:,:,0]

2024-03-29 13:48:38 发布

您现在位置:Python中文网/ 问答频道 /正文

在TF 2.0测试版中,我尝试:

x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>

在TF 1.x中,我可以做到:

^{pr2}$

而且会很好的。如何在tf2.0中实现这一点?我想

tf.split(x, 2, axis=2)

可能有用,但是我想用切片而不是硬编码2(轴2的尺寸)。在


Tags: noneinputlayerstf切片unknownkerassplit
1条回答
网友
1楼 · 发布于 2024-03-29 13:48:38

区别在于Input返回的对象表示一个层,而不是任何类似于占位符或张量的对象。所以上面tf2.0代码中的x是层对象,而tf1.x代码中的x是张量的占位符。在

可以定义切片层来执行操作。有现成的层可用,但是对于这样一个简单的切片,Lambda层非常容易阅读,并且可能最接近您在tf1.x中使用的切片方式

像这样:

input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

可在keras模型中使用,如下所示:

^{pr2}$

当然,以上是特定于keras模型的。相反,如果使用张量而不是keras层对象,则切片的工作方式与之前完全相同。像这样:

my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]

print(my_tensor.shape)
print(sliced.shape)

输出:

(8, 240, 2)
(8, 240)

如期而至

相关问题 更多 >