使用tf.gather在keras中进行自定义正则化

2024-05-14 10:16:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试在keras中实现一个自定义正则化器。其思想是,正则化的范围仅限于数据集中的两列。以下是一个玩具数据集:

# dataset
from sklearn.datasets import load_breast_cancer
data = load_breast_cancer()
x = data.data # x.shape() == (569, 30)
y = data.target

下面是我如何编写正则化器的:

import tensorflow as tf # tf.__version__ == 2.0.0

class MyRegularizer(tf.keras.regularizers.Regularizer):

    def __init__(self, strength):
        self.strength = strength

    def __call__(self, x):
        # print(tf.shape(x))
        return self.strength * tf.reduce_sum(tf.subtract(tf.gather(params=x,indices=[29],axis=1),
                                                         tf.gather(params=x,indices=[28],axis=1)
                                                        )
                                            )

这是一个玩具模型:

# model
inputs = tf.keras.layers.Input(shape=x.shape[1])
dense = tf.keras.layers.Dense(1, kernel_regularizer=MyRegularizer(0.01)
                             )(inputs)
model = tf.keras.models.Model(inputs = inputs, outputs = dense)
model.compile(loss='binary_crossentropy')
model.summary()

model.fit(x,y)

我得到的错误如下:

InvalidArgumentError: segment_ids[0] = 28 is out of range [0, 1)

我确实尝试在模型之外检查正则化器的输出函数

tf.reduce_sum(tf.subtract(tf.gather(x,[29],axis=1),tf.gather(x,[28],axis=1)))

而且运行良好

所以,发送给正则化子的张量的形状可能有问题。我不知道如何解决这个问题(使用变量名、数据类型、输入形状,所有这些都没有运气)。没有正则化器的模型拟合没有任何误差

互联网上关于上述错误的线程围绕嵌入维度展开,我没有找到适合我的解决方案


Tags: 数据模型selfdatamodeltfstrengthkeras
1条回答
网友
1楼 · 发布于 2024-05-14 10:16:29

在子类中,“call()”方法中传递的参数“x”是层内核(权重)。由于在数据层中有一个单元格,“tf.gather”方法无法在内核的第二个轴上找到索引[28]的元素

InvalidArgumentError: segment_ids[0] = 28 is out of range [0, 1)

如果要获得与[28]输入相对应的权重;我认为下面的代码可以工作(将轴值更改为零):

import tensorflow as tf # tf.__version__ == 2.0.0

class MyRegularizer(tf.keras.regularizers.Regularizer):

def __init__(self, strength):
    self.strength = strength

def __call__(self, x):
    # print(tf.shape(x))
    return self.strength * tf.reduce_sum(tf.subtract(tf.gather(params=x,indices=[29],axis=0),
                                                     tf.gather(params=x,indices=[28],axis=0)
                                                    )
                                        )

相关问题 更多 >

    热门问题