Why is Keras' false-negative count negative (e.g. -10)?


I am trying to debug a Keras model that performs very poorly on a binary text-classification task.

I stripped out all the bells and whistles and tried to fit it on two different datasets (the same X data, but different Y labels):

  • Y0: all labels set to 0
  • Y1: all labels set to 1

Each dataset has roughly 1K samples.
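
For concreteness, here is a minimal sketch (the names are mine, not from the original code) of what the two degenerate label vectors look like:

import numpy as np

n = 1000                     # roughly the dataset size described above
y0 = np.zeros(n, dtype=int)  # Y0: every label is 0
y1 = np.ones(n, dtype=int)   # Y1: every label is 1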

I then fitted the model several times, varying parameters such as the learning rate and the layer sizes, and switching between one-hot and integer encodings of the words.

Surprisingly, this test revealed that some metrics were giving me wrong results:

[Image: stats of the model when fitted with the Y0 and Y1 datasets]

Why is the FN count negative?

I did some checking. The negative count (e.g. -87) seems to corrupt other metrics as well, such as recall (which even goes above 1), mean absolute error, and accuracy.
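
The recall above 1 follows directly from the negative FN, since recall = TP / (TP + FN). Taking the numbers from the training log below at step 352/618:

tp, fn = 368, -16
recall = tp / (tp + fn)  # 368 / 352 = 1.0455, matching 'recall: 1.0455' in the log
print(recall)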

Below is the code I am running (it compiles fine):

import keras
import keras_metrics

DEFAULT_INNER_ACTIVATION = 'relu'
DEFAULT_OUTPUT_ACTIVATION = 'softplus'


class MyNN:
    # Layer-size constants (DEFAULT_MODEL_L1_STC_DIM and friends) and the training
    # hyperparameters referenced through self (LEARNING_RATE, batch_size, ...) are
    # defined elsewhere in the project.

    def __init__(self, sentence_max_lenght, ctx_max_len, dense_features_dim, vocab_size):

        lstm_input_phrase = keras.layers.Input(shape=(sentence_max_lenght,), name='L0_STC_MyApp')

        # The snippet as posted used lstm_emb_phrase before defining it; an Embedding
        # layer (the usual step between integer-encoded text and an LSTM) is assumed
        # here, with a hypothetical DEFAULT_MODEL_EMB_DIM size constant.
        lstm_emb_phrase = keras.layers.Embedding(vocab_size, DEFAULT_MODEL_EMB_DIM)(lstm_input_phrase)
        lstm_emb_phrase = keras.layers.LSTM(DEFAULT_MODEL_L1_STC_DIM, name='L1_STC_MyApp')(lstm_emb_phrase)
        lstm_emb_phrase = keras.layers.Dense(DEFAULT_MODEL_L2_STC_DIM, name='L2_STC_MyApp', activation=DEFAULT_INNER_ACTIVATION)(lstm_emb_phrase)

        x = keras.layers.Dense(DEFAULT_MODEL_L3_DIM, activation=DEFAULT_INNER_ACTIVATION)(lstm_emb_phrase)
        x = keras.layers.Dense(DEFAULT_MODEL_L4_DIM, activation=DEFAULT_INNER_ACTIVATION)(x)

        main_output = keras.layers.Dense(2, activation=DEFAULT_OUTPUT_ACTIVATION)(x)

        self.model = keras.models.Model(inputs=lstm_input_phrase,
                                        outputs=main_output)

        optimizer = keras.optimizers.Adam(lr=self.LEARNING_RATE)

        self.model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy',
                                                                                     'mae',
                                                                                     keras_metrics.precision(),
                                                                                     keras_metrics.recall(),
                                                                                     keras_metrics.binary_precision(),
                                                                                     keras_metrics.binary_recall(),
                                                                                     keras_metrics.binary_true_positive(),
                                                                                     keras_metrics.binary_true_negative(),
                                                                                     keras_metrics.binary_false_positive(),
                                                                                     keras_metrics.binary_false_negative()])


    def fit(self, x_lstm_phrase, x_lstm_context, x_lstm_pos, x_dense, y):

        # Only the phrase input is actually used in this stripped-down model;
        # the other feature arguments are kept for interface compatibility.
        x_arr = keras.preprocessing.sequence.pad_sequences(x_lstm_phrase)

        y_onehot = MyNN.onehot_transform(y)

        return self.model.fit(x_arr,
                       y_onehot,
                       batch_size=self.batch_size,
                       epochs=self.max_epochs,
                       validation_split=self.validation_split,
                       callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                                                min_delta=0.0001,
                                                                patience=self.patience,
                                                                restore_best_weights=True
                                                                )])
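
As an aside, MyNN.onehot_transform is referenced above but not shown; presumably it turns the 0/1 label vector into (n, 2) one-hot rows to match the Dense(2) output. A minimal sketch of such a helper (my assumption, using keras.utils.to_categorical):

    @staticmethod
    def onehot_transform(y):
        # sketch only: map a vector of 0/1 labels to two-column one-hot rows
        return keras.utils.to_categorical(y, num_classes=2)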



Here is a snippet of the first part of the output I get in the terminal:

Note: there are two warnings here; I do not think they affect the problem.

Using TensorFlow backend.
2019-04-01 23:26:59.479064: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
WARNING:tensorflow:From [path_to_myApp]\venv\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (f
rom tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From [path_to_myApp]\venv\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.p
ython.ops.math_ops) is deprecated and will be removed in a future version.

 16/618 [..............................] - ETA: 38s - loss: 0.7756 - binary_accuracy: 0.5000 - mean_absolute_error: 0.5007 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 16.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
 32/618 [>.............................] - ETA: 23s - loss: 0.7740 - binary_accuracy: 0.5000 - mean_absolute_error: 0.5000 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 32.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
 48/618 [=>............................] - ETA: 17s - loss: 0.7725 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4994 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 48.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
 64/618 [==>...........................] - ETA: 15s - loss: 0.7711 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4988 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 64.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
 80/618 [==>...........................] - ETA: 13s - loss: 0.7697 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4982 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 80.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
 96/618 [===>..........................] - ETA: 12s - loss: 0.7682 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4976 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 96.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
112/618 [====>.........................] - ETA: 11s - loss: 0.7666 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4970 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 112.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
128/618 [=====>........................] - ETA: 10s - loss: 0.7650 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4963 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 128.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
144/618 [=====>........................] - ETA: 9s - loss: 0.7634 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4956 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 144.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00 
160/618 [======>.......................] - ETA: 9s - loss: 0.7617 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4949 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 160.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
176/618 [=======>......................] - ETA: 8s - loss: 0.7600 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4941 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 176.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
192/618 [========>.....................] - ETA: 8s - loss: 0.7582 - binary_accuracy: 0.5000 - mean_absolute_error: 0.4934 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 192.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00

...and here is where the FN count starts going negative:


256/618 [===========>..................] - ETA: 5s - loss: 0.3052 - binary_accuracy: 0.8750 - mean_absolute_error: 0.2778 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 256.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
272/618 [============>.................] - ETA: 5s - loss: 0.2965 - binary_accuracy: 0.8824 - mean_absolute_error: 0.2791 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 272.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
288/618 [============>.................] - ETA: 5s - loss: 0.2882 - binary_accuracy: 0.8889 - mean_absolute_error: 0.2807 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 288.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
304/618 [=============>................] - ETA: 4s - loss: 0.2804 - binary_accuracy: 0.8947 - mean_absolute_error: 0.2828 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 304.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
320/618 [==============>...............] - ETA: 4s - loss: 0.2730 - binary_accuracy: 0.9000 - mean_absolute_error: 0.2853 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 320.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
336/618 [===============>..............] - ETA: 4s - loss: 0.2659 - binary_accuracy: 0.9048 - mean_absolute_error: 0.2882 - precision: 1.0000 - recall: 1.0000 - precision_1: 1.0000 - recall_1: 1.0000 - true_positive: 336.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: 0.0000e+00
352/618 [================>.............] - ETA: 4s - loss: 0.2591 - binary_accuracy: 0.8864 - mean_absolute_error: 0.2914 - precision: 1.0000 - recall: 1.0455 - precision_1: 1.0000 - recall_1: 1.0455 - true_positive: 368.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -16.0000  
368/618 [================>.............] - ETA: 3s - loss: 0.2526 - binary_accuracy: 0.8696 - mean_absolute_error: 0.2950 - precision: 1.0000 - recall: 1.0870 - precision_1: 1.0000 - recall_1: 1.0870 - true_positive: 400.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -32.0000
384/618 [=================>............] - ETA: 3s - loss: 0.2464 - binary_accuracy: 0.8542 - mean_absolute_error: 0.2989 - precision: 1.0000 - recall: 1.1250 - precision_1: 1.0000 - recall_1: 1.1250 - true_positive: 432.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -48.0000
400/618 [==================>...........] - ETA: 3s - loss: 0.2404 - binary_accuracy: 0.8400 - mean_absolute_error: 0.3031 - precision: 1.0000 - recall: 1.1600 - precision_1: 1.0000 - recall_1: 1.1600 - true_positive: 464.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -64.0000
416/618 [===================>..........] - ETA: 3s - loss: 0.2346 - binary_accuracy: 0.8269 - mean_absolute_error: 0.3076 - precision: 1.0000 - recall: 1.1923 - precision_1: 1.0000 - recall_1: 1.1923 - true_positive: 496.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -80.0000
432/618 [===================>..........] - ETA: 2s - loss: 0.2291 - binary_accuracy: 0.8148 - mean_absolute_error: 0.3124 - precision: 1.0000 - recall: 1.2222 - precision_1: 1.0000 - recall_1: 1.2222 - true_positive: 528.0000 - true_negative: 0.0000e+00 - false_positive: 0.0000e+00 - false_negative: -96.0000

Any idea how I can fix this?

EDIT:

I tried removing all the keras_metrics in use, leaving only binary_accuracy.

The problem persists: loss and val_loss go almost to zero, while the accuracy stays around 0.5.

Given how the datasets are built, that would mean #TP = #FN (for Y1) and #TN = #FP (for Y0).

How can the loss be that low with such an accuracy?

Could this be related to my using

Dense(2, activation='softplus')

as the output layer?
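
One relevant observation: softplus maps into (0, ∞), so the two output units are neither bounded by 1 nor forced to sum to 1, and cannot be read as class probabilities, which is what binary_crossentropy and the confusion-matrix metrics assume. A quick numpy sketch of the difference (the example numbers are mine):

import numpy as np

z = np.array([2.0, -1.0])              # example pre-activations for the two output units
softplus = np.log1p(np.exp(z))         # [2.127, 0.313] -> sums to ~2.44, not a distribution
softmax = np.exp(z) / np.exp(z).sum()  # [0.953, 0.047] -> sums to 1, a valid distribution
print(softplus, softmax)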



Tags: false, true, error, mean, precision, keras, eta, binary
1 Answer

After some testing, I changed the activation function from softplus to softmax.

All the metrics are now within their correct ranges, even though the classifier still performs poorly.
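
Concretely, the change is just the output layer of the model in the question (a sketch reusing the names from that snippet; pairing it with categorical_crossentropy for the one-hot targets is my suggestion, not something stated above):

# softmax keeps the two outputs in [0, 1] and forces them to sum to 1
main_output = keras.layers.Dense(2, activation='softmax')(x)

# with (n, 2) one-hot targets, categorical_crossentropy is the conventional pairing
self.model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])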

HTH
