Sklearn linear SVM fails to train for multi-label classification

Posted 2024-06-11 20:43:12


I want to train a linear SVM for multi-label classification with the following code:

from sklearn.svm import LinearSVC
from sklearn.multioutput import MultiOutputClassifier
import numpy as np

data = np.loadtxt('tictac_multi.txt')
X = data[:,:9]
y = data[:,9:]

clf = MultiOutputClassifier(LinearSVC(random_state=0, tol=1e-5, C=100, penalty='l2',max_iter=2000))
clf.fit(X, y)
print(clf.score(X, y))
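For reference (this note is not part of the original question): scikit-learn's `MultiOutputClassifier` already fits one independent clone of the base estimator per output column and exposes the fitted clones as `estimators_`. A minimal sketch that checks this, assuming random features and made-up targets as a stand-in for the 9 tic-tac-toe output columns:

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import LinearSVC

# Synthetic stand-in data (an assumption, not tictac_multi.txt):
# 200 samples, 9 features, 9 binary output columns
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 9))
# Each output column is a noisy threshold of the matching feature
y = (X + 0.5 * rng.normal(size=X.shape) > 0).astype(int)

clf = MultiOutputClassifier(LinearSVC(random_state=0, dual=False))
clf.fit(X, y)

# One fitted LinearSVC per output column
print(len(clf.estimators_))  # 9

# Stacking the per-column predictions reproduces clf.predict
manual = np.column_stack([est.predict(X) for est in clf.estimators_])
print(np.array_equal(manual, clf.predict(X)))  # True
```

If this holds for your scikit-learn version, the training side of the question's code already consists of 9 binary classifiers; what it reports through `score` is a separate matter.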

The dataset can be found here: https://www.connellybarnes.com/work/class/2016/deep_learning_graphics/proj1/tictac_multi.txt

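I have tried tuning different parameters, such as C, tol, max_iter and others, but the linear SVM still does not train well. Whatever parameters I adjust, the training accuracy stays at about 0.01.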

The output of the above code is:

Warning (from warnings module):
  File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
    warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.

(The same ConvergenceWarning is printed 9 times in total.)
0.011601282246985194

With the current code, the accuracy is 0.0116.
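One factor that may be worth checking (this goes beyond the original post): `MultiOutputClassifier.score` returns subset accuracy, so a sample only counts as correct when all of its output labels are predicted exactly. The sketch below uses small made-up arrays (hypothetical data, not the tic-tac-toe file) to show how far subset accuracy can fall below per-label accuracy.

```python
import numpy as np

# Hypothetical toy labels: 4 samples x 3 binary outputs
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
y_pred = np.array([[1, 0, 0],   # one label wrong
                   [0, 1, 0],   # all labels right
                   [1, 0, 0],   # one label wrong
                   [0, 0, 0]])  # one label wrong

# Subset accuracy: a row counts only if every label matches
# (the quantity MultiOutputClassifier.score reports)
subset_acc = np.mean(np.all(y_true == y_pred, axis=1))

# Per-label accuracy: fraction of individual labels predicted correctly
per_label_acc = np.mean(y_true == y_pred)

print(subset_acc)     # 0.25
print(per_label_acc)  # 0.75
```

With 9 output columns instead of 3, even fairly good per-cell classifiers can produce a very low exact-match score.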


1 answer
User
#1 · Posted 2024-06-11 20:43:12

This looks like a "TicTacToe" dataset (judging from the file name and the format).

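Assuming that the first nine columns of the dataset describe the 9 cells of the board at a given moment of the game, and that the other nine columns mark the cells that correspond to good moves, you can train one classifier per cell to predict whether that cell is a good move.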
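To make the assumed layout concrete, here is a sketch with one hypothetical row. The cell encoding (1 = X, -1 = O, 0 = empty) and the row values are assumptions for illustration, not read from the real file.

```python
import numpy as np

# One hypothetical row: 9 board cells followed by 9 good-move flags
row = np.array([1, 0, -1,
                0, 1,  0,
                0, 0,  0,
                0, 0,  0,
                0, 0,  0,
                0, 0,  1])

board = row[:9].reshape(3, 3)       # shared input features for every classifier
good_moves = row[9:].reshape(3, 3)  # one binary target per cell

print(board)
print(good_moves)

# Target names x1..x9 as in the code below; each cell is its own binary problem
targets = {f'x{i + 1}': int(flag) for i, flag in enumerate(row[9:])}
good = [name for name, flag in targets.items() if flag == 1]
print(good)  # ['x9']
```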

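So you actually need to train 9 binary classifiers, not one. Based on this idea, I sketched a very simple approach in the code below. After splitting the dataset into train/test sets (80/20), start with a simple cross-validation: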

import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.metrics import classification_report
import pandas as pd

# Load data into a DataFrame holding inputs and outputs
# (keep the first 18 columns: 9 board cells + 9 good-move flags)
df = pd.read_csv('tictac_multi.txt', sep=' ', header=None)[list(range(18))].copy()
df.columns = pd.MultiIndex.from_product((('input', 'output'), [f'x{i}' for i in range(1, 10)]))

# Split the dataset 80/20 (train_test_split also shuffles it)
X_train, X_test, y_train, y_test = train_test_split(df['input'].values, df['output'].values, test_size=0.2, random_state=42)

# Get cross-validation scores for one binary classifier per output cell (x1 to x9)
scores = {
    s: cross_validate(
        LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5),
        X_train, y_train[:, i], cv=5, scoring=['balanced_accuracy', 'precision', 'recall', 'f1_weighted'],
        n_jobs=-1,
    ) for i, s in enumerate(df['output'].columns)
}

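As you can see, I used some non-default options for the classifier (dual=False, class_weight='balanced'): they are just an educated guess. You should investigate further to better understand the data and the problem, and then search for the best parameters for your model (e.g. with a grid search).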
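As a rough illustration of such a grid search, the sketch below tunes `C` and `class_weight` for a single output column with `GridSearchCV`. The parameter grid is an arbitrary assumption, and the synthetic data from `make_classification` merely stands in for `X_train` and one column of `y_train`.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC

# Synthetic stand-in for X_train and one column of y_train (imbalanced, about 75/25)
X, y = make_classification(n_samples=300, n_features=9, n_informative=5,
                           weights=[0.75], random_state=0)

# Hypothetical search space; adjust it for the real data
param_grid = {
    'C': [0.01, 0.1, 1, 10],
    'class_weight': [None, 'balanced'],
}

search = GridSearchCV(
    LinearSVC(random_state=0, dual=False, tol=1e-5),
    param_grid,
    scoring='balanced_accuracy',
    cv=5,
)
search.fit(X, y)

print(search.best_params_)
print(round(search.best_score_, 3))
```

For the real problem this search would be repeated for each of the nine columns `y_train[:, i]`, since each cell may prefer different settings.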

Here are the scores:

{'x1': {'fit_time': array([0.01000571, 0.00814652, 0.00937247, 0.00622296, 0.00536656]),
  'score_time': array([0.01159358, 0.00597596, 0.00835085, 0.00647163, 0.00619125]),
  'test_balanced_accuracy': array([0.52209841, 0.51820565, 0.53743952, 0.55455645, 0.53620968]),
  'test_precision': array([0.25454545, 0.25      , 0.26611227, 0.27659574, 0.26295585]),
  'test_recall': array([0.5060241 , 0.52016129, 0.51612903, 0.5766129 , 0.55241935]),
  'test_f1_weighted': array([0.56543736, 0.55328701, 0.58232694, 0.57711117, 0.56292617])},
 'x2': {'fit_time': array([0.00737047, 0.00885296, 0.00616217, 0.00707698, 0.0071764 ]),
  'score_time': array([0.00650406, 0.00595641, 0.00623679, 0.00636506, 0.00567913]),
  'test_balanced_accuracy': array([0.57367382, 0.5342687 , 0.55287658, 0.56565243, 0.57909451]),
  'test_precision': array([0.22520661, 0.20041754, 0.21073559, 0.22037422, 0.23175966]),
  'test_recall': array([0.5828877 , 0.51336898, 0.56684492, 0.56684492, 0.57446809]),
  'test_f1_weighted': array([0.6183652 , 0.60068273, 0.59707974, 0.61584554, 0.63060231])},
 'x3': {'fit_time': array([0.0067966 , 0.00759745, 0.00617337, 0.00679278, 0.00650382]),
  'score_time': array([0.00605631, 0.00537109, 0.00551271, 0.00665474, 0.00649571]),
  'test_balanced_accuracy': array([0.52683332, 0.54103562, 0.56227539, 0.53312408, 0.51986383]),
  'test_precision': array([0.25502008, 0.26639344, 0.28367347, 0.26035503, 0.25      ]),
  'test_recall': array([0.51626016, 0.52845528, 0.56275304, 0.53441296, 0.53036437]),
  'test_f1_weighted': array([0.56805171, 0.58208858, 0.59506983, 0.56776364, 0.55079222])},
 'x4': {'fit_time': array([0.00649667, 0.00767159, 0.00802064, 0.00769711, 0.00611663]),
  'score_time': array([0.00572419, 0.00529647, 0.00616765, 0.00592041, 0.00609517]),
  'test_balanced_accuracy': array([0.53369766, 0.57259312, 0.57644138, 0.55746825, 0.51877354]),
  'test_precision': array([0.19791667, 0.22290389, 0.22540984, 0.21489362, 0.18930041]),
  'test_recall': array([0.51351351, 0.58602151, 0.59139785, 0.54301075, 0.49462366]),
  'test_f1_weighted': array([0.6005693 , 0.615313  , 0.61784599, 0.61784823, 0.58924053])},
 'x5': {'fit_time': array([0.00650501, 0.005898  , 0.00682783, 0.00629449, 0.00635648]),
  'score_time': array([0.00553894, 0.0059135 , 0.00625896, 0.00583744, 0.00580502]),
  'test_balanced_accuracy': array([0.51108635, 0.50499149, 0.52183641, 0.53230958, 0.51296946]),
  'test_precision': array([0.30185185, 0.29735234, 0.31163708, 0.322     , 0.30522088]),
  'test_recall': array([0.53094463, 0.47557003, 0.51465798, 0.52272727, 0.49350649]),
  'test_f1_weighted': array([0.5248707 , 0.53861778, 0.54612005, 0.55679291, 0.54217533])},
 'x6': {'fit_time': array([0.00703621, 0.00908065, 0.00665092, 0.00619102, 0.00814819]),
  'score_time': array([0.00568652, 0.00626183, 0.00584817, 0.00574327, 0.00552726]),
  'test_balanced_accuracy': array([0.55457928, 0.55569106, 0.50701258, 0.53690769, 0.56919396]),
  'test_precision': array([0.2145749 , 0.21621622, 0.18480493, 0.20416667, 0.22540984]),
  'test_recall': array([0.56084656, 0.55026455, 0.47619048, 0.51851852, 0.57894737]),
  'test_f1_weighted': array([0.60241544, 0.61008882, 0.5813744 , 0.60080544, 0.6130977 ])},
 'x7': {'fit_time': array([0.0070405 , 0.00908256, 0.00702643, 0.00635576, 0.00632381]),
  'score_time': array([0.00546646, 0.00674367, 0.00542998, 0.00671315, 0.00549483]),
  'test_balanced_accuracy': array([0.53124816, 0.52187224, 0.54180051, 0.57438252, 0.52764072]),
  'test_precision': array([0.27054108, 0.26235741, 0.27659574, 0.30364372, 0.26824034]),
  'test_recall': array([0.52325581, 0.53488372, 0.55642023, 0.58365759, 0.48638132]),
  'test_f1_weighted': array([0.56745684, 0.54860915, 0.56677092, 0.5996452 , 0.57954721])},
 'x8': {'fit_time': array([0.00761437, 0.00997519, 0.006984  , 0.00623441, 0.00683069]),
  'score_time': array([0.00540686, 0.00635052, 0.00645804, 0.00535131, 0.00548935]),
  'test_balanced_accuracy': array([0.51471322, 0.56996108, 0.52712724, 0.5443143 , 0.55319282]),
  'test_precision': array([0.18661258, 0.22292994, 0.192607  , 0.20408163, 0.20874751]),
  'test_recall': array([0.49462366, 0.56451613, 0.53513514, 0.54054054, 0.56756757]),
  'test_f1_weighted': array([0.58328382, 0.62374708, 0.57815794, 0.60051373, 0.59779516])},
 'x9': {'fit_time': array([0.00723267, 0.0069263 , 0.00828266, 0.00672913, 0.00750995]),
  'score_time': array([0.00545311, 0.00556946, 0.00732398, 0.0056181 , 0.00555682]),
  'test_balanced_accuracy': array([0.53490307, 0.55281703, 0.58447809, 0.52272419, 0.54294236]),
  'test_precision': array([0.26388889, 0.27868852, 0.29811321, 0.25506073, 0.27198364]),
  'test_recall': array([0.53413655, 0.54618474, 0.63453815, 0.5060241 , 0.532     ]),
  'test_f1_weighted': array([0.56987212, 0.58922553, 0.59075641, 0.56631422, 0.5819019 ])}}

As you can see, they are not great, but they are far from 0.
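The raw fold arrays above are hard to compare at a glance. A minimal sketch (the helper name `summarize_cv` is hypothetical) that reduces a `{label: cross_validate result}` dict to the mean of each test metric, shown on a small hand-made dict rather than the real scores:

```python
import numpy as np


def summarize_cv(scores):
    """Reduce {label: cross_validate result} to {label: {test metric: mean over folds}}."""
    return {
        label: {metric: float(np.mean(values))
                for metric, values in result.items()
                if metric.startswith('test_')}  # skip fit_time / score_time
        for label, result in scores.items()
    }


# Hand-made example with the same shape as the scores dict above (values are made up)
example = {
    'x1': {'fit_time': np.array([0.01, 0.02]),
           'score_time': np.array([0.01, 0.01]),
           'test_balanced_accuracy': np.array([0.50, 0.75]),
           'test_f1_weighted': np.array([0.25, 0.75])},
}

summary = summarize_cv(example)
print(summary)  # {'x1': {'test_balanced_accuracy': 0.625, 'test_f1_weighted': 0.5}}
```

With the real data, `summarize_cv(scores)` would give one set of averages per cell.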

Now train the models on the whole training set and evaluate their performance on the test data:

def train_clfs(clfs, X, y):
    # Fit each classifier on its own output column (sorted keys x1..x9 follow the column order)
    return {s: clf.fit(X, y[:, i]) for i, (s, clf) in enumerate(sorted(clfs.items()))}


def get_predictions(clfs, inp):
    # Predict with every per-cell classifier
    return {s: clf.predict(inp) for s, clf in clfs.items()}

# Train the classifiers
clfs = {s: LinearSVC(random_state=0, dual=False, class_weight='balanced', tol=1e-5) for s in sorted(df['output'].columns)}
clfs = train_clfs(clfs, X_train, y_train)

# Try them on the test values
pred = get_predictions(clfs, X_test)

# Get and print the classification report for each classifier
cl_report = {s: classification_report(y_test[:, i], p) for i, (s, p) in enumerate(sorted(pred.items()))}
for s, report in cl_report.items():
    print(s)
    print(report)

Here is the performance:

x1
              precision    recall  f1-score   support

           0       0.76      0.52      0.62       988
           1       0.25      0.49      0.33       323

    accuracy                           0.51      1311
   macro avg       0.50      0.51      0.48      1311
weighted avg       0.63      0.51      0.55      1311


x2
              precision    recall  f1-score   support

           0       0.87      0.56      0.68      1086
           1       0.22      0.58      0.31       225

    accuracy                           0.57      1311
   macro avg       0.54      0.57      0.50      1311
weighted avg       0.75      0.57      0.62      1311


x3
              precision    recall  f1-score   support

           0       0.79      0.50      0.61       998
           1       0.26      0.57      0.36       313

    accuracy                           0.52      1311
   macro avg       0.53      0.54      0.49      1311
weighted avg       0.66      0.52      0.55      1311


x4
              precision    recall  f1-score   support

           0       0.84      0.54      0.65      1061
           1       0.22      0.57      0.32       250

    accuracy                           0.54      1311
   macro avg       0.53      0.55      0.49      1311
weighted avg       0.72      0.54      0.59      1311


x5
              precision    recall  f1-score   support

           0       0.72      0.53      0.61       926
           1       0.31      0.50      0.38       385

    accuracy                           0.52      1311
   macro avg       0.51      0.52      0.50      1311
weighted avg       0.60      0.52      0.54      1311


x6
              precision    recall  f1-score   support

           0       0.85      0.57      0.69      1077
           1       0.22      0.54      0.31       234

    accuracy                           0.57      1311
   macro avg       0.53      0.56      0.50      1311
weighted avg       0.74      0.57      0.62      1311


x7
              precision    recall  f1-score   support

           0       0.81      0.55      0.65      1021
           1       0.25      0.53      0.34       290

    accuracy                           0.55      1311
   macro avg       0.53      0.54      0.50      1311
weighted avg       0.68      0.55      0.59      1311


x8
              precision    recall  f1-score   support

           0       0.84      0.55      0.66      1069
           1       0.21      0.53      0.30       242

    accuracy                           0.55      1311
   macro avg       0.52      0.54      0.48      1311
weighted avg       0.72      0.55      0.60      1311


x9
              precision    recall  f1-score   support

           0       0.79      0.54      0.64      1006
           1       0.26      0.52      0.35       305

    accuracy                           0.54      1311
   macro avg       0.52      0.53      0.49      1311
weighted avg       0.67      0.54      0.57      1311
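To compare these per-cell results with the 0.0116 reported in the question, one option (a sketch, not part of the original answer) is to stack the per-cell predictions back into a matrix and score it with scikit-learn's multilabel metrics. `stack_predictions` is a hypothetical helper and the toy arrays are made up; on the real data it would be called as `stack_predictions(pred)` and compared against `y_test`.

```python
import numpy as np
from sklearn.metrics import accuracy_score, hamming_loss


def stack_predictions(pred):
    """Stack a {label: 1-D predictions} dict into an (n_samples, n_labels) array,
    ordering columns by sorted label name (x1, x2, ... as in y_test)."""
    return np.column_stack([p for _, p in sorted(pred.items())])


# Hypothetical 3-cell example (the real dict has keys x1 to x9)
demo_pred = {
    'x1': np.array([1, 0, 0, 1]),
    'x2': np.array([0, 1, 0, 0]),
    'x3': np.array([1, 0, 1, 0]),
}
demo_true = np.array([[1, 0, 1],
                      [0, 1, 0],
                      [0, 0, 0],
                      [1, 1, 0]])

demo_matrix = stack_predictions(demo_pred)

# With 2-D label-indicator input, accuracy_score is subset (exact-match) accuracy
subset_acc = accuracy_score(demo_true, demo_matrix)
# hamming_loss is the fraction of individual cells predicted wrongly
cell_err = hamming_loss(demo_true, demo_matrix)

print(subset_acc)  # 0.5
print(cell_err)
```

Reporting both numbers makes it clear whether a low score comes from many wrong cells or from a few wrong cells spread across many boards.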
