如何使用预处理层保存keras模型?

2024-03-19 04:27:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个keras模型,保存到h5文件中。当我从h5文件加载模型时,预测是正确的,但是解码预测失败,除非我处于创建h5文件的相同过程中

这些是前/后处理层:

# Mapping characters to integers
char_to_num = layers.experimental.preprocessing.StringLookup(
    vocabulary=list(characters), num_oov_indices=0, mask_token=None
)

# Mapping integers back to original characters
num_to_char = layers.experimental.preprocessing.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), num_oov_indices=0, mask_token=None, invert=True
)

有关预测失败的确切位置,请参见下面的注释:

def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # results is correct here
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        # res is incorrect here
        output_text.append(res)
    return output_text

我尝试过酸洗num_to_char和char_to_num,因为这是它失败的原因,但它仍然失败

培训过程:

torch.save(num_to_char.get_config(), 'num_to_char.pt')
torch.save(char_to_num.get_config(), 'char_to_num.pt')

生产过程:

char_to_num = layers.experimental.preprocessing.StringLookup.from_config(torch.load('char_to_num.pt'))
num_to_char = layers.experimental.preprocessing.StringLookup.from_config(torch.load('num_to_char.pt'))

Tags: 文件toptconfig过程layersrestorch