我想用以下代码训练具有多标签分类的线性SVM:
from sklearn.svm import LinearSVC
from sklearn.multioutput import MultiOutputClassifier
import numpy as np
data = np.loadtxt('tictac_multi.txt')
X = data[:,:9]
y = data[:,9:]
clf = MultiOutputClassifier(LinearSVC(random_state=0, tol=1e-5, C=100, penalty='l2',max_iter=2000))
clf.fit(X, y)
print(clf.score(X, y))
数据集可以在这里找到https://www.connellybarnes.com/work/class/2016/deep_learning_graphics/proj1/tictac_multi.txt
我尝试调整不同的参数,如C、tol、max_iter和其他参数。线性支持向量机模型仍然不能很好地训练。无论我调整任何参数,训练精度仍然小于0.01
上述代码的输出为:
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
Warning (from warnings module):
File "C:\Users\hyu14\AppData\Local\Programs\Python\Python38-32\lib\site-packages\sklearn\svm\_base.py", line 946
warnings.warn("Liblinear failed to converge, increase "
ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
0.011601282246985194
与当前代码相比,精度为0.0116
它看起来像一个“TicTacToe”数据集(从文件名和格式来看)
假设datset的前九列提供了游戏中特定时刻9个单元格的描述,而其他九列表示与良好移动对应的单元格,则可以逐个单元格训练分类器单元格,以预测单元格是否为良好移动
所以,你实际上需要训练9个二进制分类器,而不是一个。基于这个想法,我在下面的代码中草拟了一个非常简单的方法。在列车/测试(80/20)中拆分数据集后,从简单交叉验证开始:
如您所见,我为分类器使用了一些非默认选项(
dual=False, class_weight='balanced'
):它们只是一个有根据的猜测,您应该进行更多调查,以更好地了解数据和问题,然后为您的模型寻找最佳参数(例如,网格搜索)下面是分数:
正如你所看到的,它们不是很好,但远远不是0
现在,在整个列车数据集上训练模型,并在测试数据上评估性能:
下面是表演:
相关问题 更多 >
编程相关推荐