使用scikit学习区分相似的类别

2024-05-12 13:26:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我想把文档中的文本分为不同的类别。每个文档只能进入以下类别之一:PR、AR、KID、SAR。你知道吗

我找到了一个使用scikit learn的示例,并且能够使用它:

import numpy
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.svm import LinearSVC
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.multiclass import OneVsRestClassifier
from pandas import DataFrame

def build_data_frame(path, classification):
    rows = []
    index = []

    f = open(path, mode = 'r', encoding="utf8")
    txt = f.read()

    rows.append({'text': txt, 'class': classification})
    index.append(path)

    data_frame = DataFrame(rows, index=index)
    return data_frame

# Categories
PR = 'PR'
AR = 'AR'
KID = 'KID'
SAR = 'SAR'

# Training documents
SOURCES = [
    (r'C:/temp_training/PR/PR1.txt', PR),
    (r'C:/temp_training/PR/PR2.txt', PR),
    (r'C:/temp_training/PR/PR3.txt', PR),
    (r'C:/temp_training/PR/PR4.txt', PR),
    (r'C:/temp_training/PR/PR5.txt', PR),
    (r'C:/temp_training/AR/AR1.txt', AR),
    (r'C:/temp_training/AR/AR2.txt', AR),
    (r'C:/temp_training/AR/AR3.txt', AR),
    (r'C:/temp_training/AR/AR4.txt', AR),
    (r'C:/temp_training/AR/AR5.txt', AR),
    (r'C:/temp_training/KID/KID1.txt', KID),
    (r'C:/temp_training/KID/KID2.txt', KID),
    (r'C:/temp_training/KID/KID3.txt', KID),
    (r'C:/temp_training/KID/KID4.txt', KID),
    (r'C:/temp_training/KID/KID5.txt', KID),
    (r'C:/temp_training/SAR/SAR1.txt', SAR),
    (r'C:/temp_training/SAR/SAR2.txt', SAR),
    (r'C:/temp_training/SAR/SAR3.txt', SAR),
    (r'C:/temp_training/SAR/SAR4.txt', SAR),
    (r'C:/temp_training/SAR/SAR5.txt', SAR)
]

# Real documents
TESTS = [
    (r'C:/temp_testing/PR/PR1.txt'),
    (r'C:/temp_testing/PR/PR2.txt'),
    (r'C:/temp_testing/PR/PR3.txt'),
    (r'C:/temp_testing/PR/PR4.txt'),
    (r'C:/temp_testing/PR/PR5.txt'),
    (r'C:/temp_testing/AR/AR1.txt'),
    (r'C:/temp_testing/AR/AR2.txt'),
    (r'C:/temp_testing/AR/AR3.txt'),
    (r'C:/temp_testing/AR/AR4.txt'),
    (r'C:/temp_testing/AR/AR5.txt'),
    (r'C:/temp_testing/KID/KID1.txt'),
    (r'C:/temp_testing/KID/KID2.txt'),
    (r'C:/temp_testing/KID/KID3.txt'),
    (r'C:/temp_testing/KID/KID4.txt'),
    (r'C:/temp_testing/KID/KID5.txt'),
    (r'C:/temp_testing/SAR/SAR1.txt'),
    (r'C:/temp_testing/SAR/SAR2.txt'),
    (r'C:/temp_testing/SAR/SAR3.txt'),
    (r'C:/temp_testing/SAR/SAR4.txt'),
    (r'C:/temp_testing/SAR/SAR5.txt')
]

data_train = DataFrame({'text': [], 'class': []})
for path, classification in SOURCES:
    data_train = data_train.append(build_data_frame(path, classification))

data_train = data_train.reindex(numpy.random.permutation(data_train.index))

examples = []

for path in TESTS:
    f = open(path, mode = 'r', encoding = 'utf8')
    txt = f.read()

    examples.append(txt)

target_names = [PR, AR, KID, SAR]

classifier = Pipeline([
    ('vectorizer', CountVectorizer(ngram_range=(1, 2), analyzer='word', strip_accents='unicode', stop_words='english')),
    ('tfidf', TfidfTransformer()),
    ('clf', OneVsRestClassifier(LinearSVC()))])
classifier.fit(data_train['text'], data_train['class'])
predicted = classifier.predict(examples)

print(predicted)

输出:

['PR' 'PR' 'PR' 'PR' 'PR' 'AR' 'AR' 'AR' 'AR' 'AR' 'KID' 'KID' 'KID' 'KID'
 'KID' 'AR' 'AR' 'AR' 'SAR' 'AR']

PR,AR和KID都是公认的。你知道吗

然而,SAR文档(最后5个)除了其中一个之外没有正确分类。SAR和AR非常相似,这可以解释为什么算法会混淆。你知道吗

我试着使用n-grams值,但是1(最小值)和2(最大值)似乎给出了最好的结果。你知道吗

  • 你知道如何提高AR和SAR分类的精度吗?

  • 有没有办法显示特定文档的识别率?i、 e.PR(70%),意味着算法对预测有70%的信心

如果您需要文档,这里是数据集:http://1drv.ms/21dnL6j


Tags: pathfrom文档importtxtdatatrainingtrain
1条回答
网友
1楼 · 发布于 2024-05-12 13:26:14

这不是一个严格意义上的编程问题,所以我建议您尝试将其发布到一个更与数据科学相关的堆栈中。你知道吗

无论如何,你可以尝试一些事情:

  • 使用其他分类器。你知道吗
  • 使用网格搜索调整分类器超参数。你知道吗
  • 使用OneVsOne代替OneVsAll作为策略。这可能会帮助您区分SAR和AR
  • 对于“显示特定文档的识别百分比”,可以使用来自某些模型的概率输出。使用classifier.predict_proba函数而不是classifier.predict函数。你知道吗

祝你好运!你知道吗

相关问题 更多 >