Unable to load a trained Keras model with a custom regularizer class

Asked 2025-04-14 17:58

I am training the PointNet 3D object classification model on my own dataset, following the Keras tutorial: https://keras.io/examples/vision/pointnet/#point-cloud-classification-with-pointnet

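Training itself runs without problems, but once it finishes I cannot load the trained model. I think the problem lies in the part below: the OrthogonalRegularizer class may not be registered correctly when I save the model: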


@keras.saving.register_keras_serializable('OrthogonalRegularizer')
class OrthogonalRegularizer(keras.regularizers.Regularizer):

    def __init__(self, num_features, **kwargs):
        super(OrthogonalRegularizer, self).__init__(**kwargs)
        self.num_features = num_features
        self.l2reg = 0.001
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.math.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

    def get_config(self):
        config = {}
        config.update({"num_features": self.num_features, "l2reg": self.l2reg, "eye": self.eye})
        return config

def tnet(inputs, num_features):
    # Initialise bias as the identity matrix
    bias = keras.initializers.Constant(np.eye(num_features).flatten())
    reg = OrthogonalRegularizer(num_features)

    x = conv_bn(inputs, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = dense_bn(x, 128)
    x = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=bias,
        activity_regularizer=reg,
    )(x)
    feat_T = layers.Reshape((num_features, num_features))(x)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])

After training, when I try to load the model as follows, I get an error:

model.save('my_model.h5')
model = keras.models.load_model('my_model.h5', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})

The error message is:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-05f700f433a8> in <cell line: 2>()
      1 model.save('my_model.h5')
----> 2 model = keras.models.load_model('my_model.h5', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})

2 frames
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_api.py in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    260 
    261     # Legacy case.
--> 262     return legacy_sm_saving_lib.load_model(
    263         filepath, custom_objects=custom_objects, compile=compile, **kwargs
    264     )

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py in from_config(cls, config)
    868             return cls(**config)
    869         except Exception as e:
--> 870             raise TypeError(
    871                 f"Error when deserializing class '{cls.__name__}' using "
    872                 f"config={config}.\n\nException encountered: {e}"

TypeError: Error when deserializing class 'Dense' using config={'name': 'dense_2', 
'trainable': True, 'dtype': 'float32', 'units': 9, 'activation': 'linear', 'use_bias': 
True, 'kernel_initializer': {'module': 'keras.initializers', 'class_name': 'Zeros', 
'config': {}, 'registered_name': None}, 'bias_initializer': {'module': 
'keras.initializers', 'class_name': 'Constant', 'config': {'value': {'class_name': 
'__numpy__', 'config': {'value': [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 'dtype': 
'float64'}}}, 'registered_name': None}, 'kernel_regularizer': None, 'bias_regularizer': 
None, 'activity_regularizer': {'module': None, 'class_name': 'OrthogonalRegularizer', 
'config': {'num_features': 3, 'l2reg': 0.001, 'eye': {'class_name': '__tensor__', 
'config': {'value': [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 'dtype': 
'float32'}}}, 'registered_name': 'OrthogonalRegularizer>OrthogonalRegularizer'}, 
'kernel_constraint': None, 'bias_constraint': None}.

Exception encountered: object.__init__() takes exactly one argument (the instance to initialize)

My understanding so far is that the OrthogonalRegularizer object is not saved correctly when I save the model. Please tell me what I am doing wrong.

A simplified version of the code is available in this Colab notebook: https://colab.research.google.com/drive/1akpfoOBVAWThsZl7moYywuZIuXt_vWCU?usp=sharing

A possibly related question: Loading a custom regularizer in Keras

1 Answer


You don't need to call

super(OrthogonalRegularizer, self).__init__(**kwargs)

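because keras.regularizers.Regularizer does not define a constructor. There is also no need to store the tensor eye (self.eye), which cannot be serialized. It is better to create that tensor close to where it is used.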
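To illustrate why constructing the regularizer works during training but fails on load, here is a minimal plain-Python sketch of the mechanism. The classes below are stand-ins written for this illustration, not Keras classes; the one assumption carried over is that the object is rebuilt with cls(**config), as the default from_config does.

```python
class Regularizer:
    """Stand-in for a base class that defines no __init__ of its own."""

    @classmethod
    def from_config(cls, config):
        # Every key returned by get_config() becomes a keyword argument.
        return cls(**config)


class BrokenRegularizer(Regularizer):
    """Mirrors the structure of the class in the question (illustration only)."""

    def __init__(self, num_features, **kwargs):
        # On reload, "l2reg" and "eye" land in kwargs and are forwarded to
        # object.__init__, which accepts no extra arguments.
        super().__init__(**kwargs)
        self.num_features = num_features
        self.l2reg = 0.001
        # A nested list stands in for the identity tensor.
        self.eye = [[float(i == j) for j in range(num_features)]
                    for i in range(num_features)]

    def get_config(self):
        return {"num_features": self.num_features,
                "l2reg": self.l2reg,
                "eye": self.eye}


# Direct construction passes no extra kwargs, so it succeeds.
original = BrokenRegularizer(3)

# Rebuilding from the config passes "l2reg" and "eye" as kwargs and fails.
try:
    BrokenRegularizer.from_config(original.get_config())
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The extra keys only appear when the object is recreated from its config, which is why the failure shows up at load time rather than while the model is being built or trained.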

The revised code should look like this:

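@keras.saving.register_keras_serializable('OrthogonalRegularizer')
class OrthogonalRegularizer(keras.regularizers.Regularizer):

    def __init__(self, num_features, l2reg=0.001, **kwargs):
        # No super().__init__() call: the base Regularizer defines no constructor.
        # **kwargs only absorbs stale keys (such as "eye") from previously saved configs.
        self.num_features = num_features
        self.l2reg = l2reg

    # Keras invokes a regularizer through __call__, not call.
    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        # Build the identity here instead of storing a tensor on the instance.
        eye = tf.eye(self.num_features)
        return tf.math.reduce_sum(self.l2reg * tf.square(xxt - eye))

    def get_config(self):
        # Only plain Python values, so the config round-trips through from_config.
        return {"num_features": self.num_features, "l2reg": self.l2reg}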
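Before retraining, it may be worth checking that a config survives a save/load cycle. The sketch below is one way to do that in plain Python: check_config is a hypothetical helper, and PlainConfigRegularizer and TensorLike are stand-ins with the same config shape as the class above, not Keras classes. Requiring the config to be JSON-encodable is a deliberately strict assumption, stricter than what Keras itself enforces.

```python
import json


def check_config(obj):
    """Hypothetical helper: True if get_config() survives a save/load cycle."""
    config = obj.get_config()
    # json.dumps raises TypeError for tensors, arrays, and other non-JSON values.
    restored_config = json.loads(json.dumps(config))
    # Mirrors a default from_config that rebuilds the object with cls(**config).
    clone = type(obj)(**restored_config)
    return clone.get_config() == config


class PlainConfigRegularizer:
    """Stand-in whose config contains only plain Python values."""

    def __init__(self, num_features, l2reg=0.001, **kwargs):
        self.num_features = num_features
        self.l2reg = l2reg

    def get_config(self):
        return {"num_features": self.num_features, "l2reg": self.l2reg}


class TensorLike:
    """Stand-in for a tensor: an object that json cannot encode."""


# A config made of ints and floats round-trips cleanly.
ok = check_config(PlainConfigRegularizer(3, l2reg=0.01))

# A config that carries a tensor-like object is rejected by the JSON encoder.
try:
    json.dumps({"num_features": 3, "eye": TensorLike()})
    rejected = False
except TypeError:
    rejected = True

print(ok, rejected)
```

Keeping derived values such as the identity matrix out of the config means everything needed to rebuild the object is just num_features and l2reg.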

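Also, it is better to save and load the model with the newer mechanism instead of the legacy HDF5 (.h5) file. Depending on the Keras version, a path without an extension is written as a SavedModel directory, while recent releases expect a .keras extension: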

model.save('my_model')
model = keras.models.load_model('my_model', custom_objects={'OrthogonalRegularizer': OrthogonalRegularizer})
