如何有条件地缩放Keras Lambda层中的值?

2024-04-27 00:00:22 发布

您现在位置:Python中文网/ 问答频道 /正文

输入张量rnn_pv的形状是(?, 48, 1)。我想缩放这个张量中的每个元素,所以我尝试使用Lambda层,如下所示:

rnn_pv_scale = Lambda(lambda x: 1 if x >=1000 else x/1000.0 )(rnn_pv)

但错误来了:

^{pr2}$

那么,实现这个功能的正确方法是什么呢?在


Tags: 方法lambda功能元素if错误else形状
1条回答
网友
1楼 · 发布于 2024-04-27 00:00:22

不能使用Python控制流语句(如if else语句)在模型定义中执行条件操作。相反,您需要使用Keras后端中定义的方法。由于您使用TensorFlow作为后端,因此可以使用tf.where()来实现:

import tensorflow as tf

scaled = Lambda(lambda x: tf.where(x >= 1000, tf.ones_like(x), x/1000.))(input_tensor)

或者,要支持所有后端,可以创建一个掩码来执行此操作:

^{pr2}$

更新:支持所有后端的另一种方法是使用K.switch方法:

from keras import backend as K

scaled = Lambda(lambda x: K.switch(x >= 1000., K.ones_like(x), x / 1000.))(input_tensor)

相关问题 更多 >