使用Keras处理看不见的类

2024-05-20 11:11:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我用Python制作了一个Keras模型,该模型对字符串输入是公司、个人还是地址进行分类。 模型基于12000个字符串数据进行训练。每个输入有1到5个单词。 这就是模型:

transformerVectoriser = ColumnTransformer(transformers=[('vector char', CountVectorizer(analyzer='char', ngram_range=(3, 6), max_features = 2000), 'text'),
                                                        ('vector word', CountVectorizer(analyzer='word', ngram_range=(1, 1), max_features = 4000), 'text')],
                                          remainder='passthrough') # Default is to drop untransformed columns


features = transformerVectoriser.fit_transform(features)


model = Sequential()
model.add(Dense(100, input_dim = features.shape[1], activation = 'relu')) # input layer requires input_dim param
model.add(Dense(200, activation = 'relu'))
model.add(Dense(100, activation = 'relu'))
model.add(Dense(50, activation = 'relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))

结果如下:

                precision    recall  f1-score   support

company         0.97         0.92      0.95       636
person          0.93         0.97      0.95       697
address         1.00         1.00      1.00       667

accuracy                               0.97      2000
macro avg       0.97         0.96      0.97      2000
weighted avg    0.97         0.97      0.97      2000

例如,如果我想用字符串输入进行预测:

input_strs = ['Amazon Inc', 'Jeff Bezos', 'Elon Musk', '24 Avenue Paris']

它将其分类为:

 ['company', 'person', 'person', 'address']

该模型运行良好,但我注意到,如果输入一个字符串,例如,表示电话号码或只是一些随机数字或一些随机字符串,有时会犯很大的错误。 例如,如果我输入:

['+435 542 425 54 24', '426266245', 'as long as the']

我得到的结果是:

 ['address', 'company', 'address']

我的问题是,如何处理一些看不见的类? 如果字符串输入不满足一些可以正确分类的基本“形式”,我如何处理这种情况


2条回答

另一个更直接但我认为长期而言不太准确的解决方案是通过以下方式在softmax之后添加一些简单的逻辑:

import numpy as np

#Initialize
softmaxoutput=np.double([1,2,3])
classes=['company','person','address']

#Let's play a littlebit with the outputs
softmaxoutput[0]=0.3
softmaxoutput[1]=0.3
softmaxoutput[2]=1-(softmaxoutput[0]+softmaxoutput[1])

#Let's decide the predicted class...
result=np.argmax(softmaxoutput)
predicted_class=classes[result]

uncertainity_threshold=0.5

#...but make an exception that...
if np.amax(softmaxoutput)<=uncertainity_threshold:
    predicted_class='hmmm...'

#And finally let's show the result
print(predicted_class)

…您可以通过参数不确定性_阈值轻松管理此附加逻辑的“效果”。如果将其值设为1,您肯定会得到与当前解决方案相同的结果……但通过减小此值,您对非逻辑分类的头痛会稍微减轻。用手测试似乎是最佳值是很简单的

您还可以找到其他解决方案,但这将是处理该问题的第二种方法

我建议你做一个名为“嗯…”之类的分类。和-用大量不属于您感兴趣的类的字符串填充此类别

制作一个小脚本很容易,它可以阅读互联网或书籍一段时间,每次它发现一个不是公司、地址或个人的字符串,它就会将其保存到“hmmm…”类别

因此,您有一个DNN,它将每个“怪异”输入分类到类“hmmm…”

您还可以找到其他解决方案,但这将是一种管理方法

相关问题 更多 >