Saving a PyML.classifiers.multi.OneAgainstRest(SVM()) object?

Asked 2025-04-15 21:49

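I'm using PyML to build a multi-class linear support vector machine (SVM). After training the SVM, I'd like to save the classifier so that on later runs I can use it directly without retraining. Unfortunately, .save() is not implemented for this classifier, and trying to save it with pickle (both the standard pickle and cPickle) gives the following error: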

pickle.PicklingError: Can't pickle : it's not found as __builtin__.PySwigObject

Does anyone know of a workaround, or of another library that avoids this problem? Thanks.

EDIT/UPDATE
I'm now training the classifier and trying to save it with the following code:

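import os
from PyML.classifiers import multi
from PyML.classifiers.svm import SVM

# dataset_pyml, prefix and labels are defined elsewhere in the program
mc = multi.OneAgainstRest(SVM())
mc.train(dataset_pyml, saveSpace=False)
for i, classifier in enumerate(mc.classifiers):
    filename = os.path.join(prefix, labels[i] + ".svm")
    classifier.save(filename)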

Note that I'm now using PyML's own save mechanism rather than pickle, and I pass "saveSpace=False" to the training function. However, I still get this error:

ValueError: in order to save a dataset you need to train as: s.train(data, saveSpace = False)

But I've already passed saveSpace=False, so how can I save this classifier?

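P.S.
If you want a complete, testable example, the project I'm using this in is pyimgattr. The program is run with "./pyimgattr.py train", which produces the error you see. Here is the version information as well: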

[michaelsafyan@codemage /Volumes/Storage/classes/cse559/pyimgattr]$ python
Python 2.6.1 (r261:67515, Feb 11 2010, 00:51:29) 
[GCC 4.2.1 (Apple Inc. build 5646)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import PyML
>>> print PyML.__version__
0.7.0

2 Answers


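On line 96 of multi.py, "self.classifiers[i].train(datai)" is called without passing along "**args". So when you call "mc.train(data, saveSpace=False)", the saveSpace argument is lost before it reaches the individual SVMs. That is why you get the error message when you try to save the sub-classifiers of your multi-class classifier one by one. If you change that line to pass along all the arguments, you can save each classifier individually: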

#!/usr/bin/python

import numpy

from PyML.utils import misc
from PyML.evaluators import assess
from PyML.classifiers.svm import SVM, loadSVM
from PyML.containers.labels import oneAgainstRest
from PyML.classifiers.baseClassifiers import Classifier
from PyML.containers.vectorDatasets import SparseDataSet
from PyML.classifiers.composite import CompositeClassifier

class OneAgainstRestFixed(CompositeClassifier) :

    '''A one-against-the-rest multi-class classifier'''

    def train(self, data, **args) :
        '''train k classifiers'''

        Classifier.train(self, data, **args)

        numClasses = self.labels.numClasses
        if numClasses <= 2:
            raise ValueError, 'Not a multi class problem'

        self.classifiers = [self.classifier.__class__(self.classifier)
                            for i in range(numClasses)]

        for i in range(numClasses) :
            # make a copy of the data; this is done in case the classifier modifies the data
            datai = data.__class__(data, deepcopy = self.classifier.deepcopy)
            datai =  oneAgainstRest(datai, data.labels.classLabels[i])

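            # forward **args (e.g. saveSpace=False) to every binary classifier;
            # the original multi.py drops them at this point
            self.classifiers[i].train(datai, **args)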

        self.log.trainingTime = self.getTrainingTime()

    def classify(self, data, i):

        r = numpy.zeros(self.labels.numClasses, numpy.float_)
        for j in range(self.labels.numClasses) :
            r[j] = self.classifiers[j].decisionFunc(data, i)

        return numpy.argmax(r), numpy.max(r)

    def preproject(self, data) :

        for i in range(self.labels.numClasses) :
            self.classifiers[i].preproject(data)

    test = assess.test

train_data = """
0 1:1.0 2:0.0 3:0.0 4:0.0
0 1:0.9 2:0.0 3:0.0 4:0.0
1 1:0.0 2:1.0 3:0.0 4:0.0
1 1:0.0 2:0.8 3:0.0 4:0.0
2 1:0.0 2:0.0 3:1.0 4:0.0
2 1:0.0 2:0.0 3:0.9 4:0.0
3 1:0.0 2:0.0 3:0.0 4:1.0
3 1:0.0 2:0.0 3:0.0 4:0.9
"""
file("foo_train.data", "w").write(train_data.lstrip())

test_data = """
0 1:1.1 2:0.0 3:0.0 4:0.0
1 1:0.0 2:1.2 3:0.0 4:0.0
2 1:0.0 2:0.0 3:0.6 4:0.0
3 1:0.0 2:0.0 3:0.0 4:1.4
"""
file("foo_test.data", "w").write(test_data.lstrip())

train = SparseDataSet("foo_train.data")
mc = OneAgainstRestFixed(SVM())
mc.train(train, saveSpace=False)

test = SparseDataSet("foo_test.data")
print [mc.classify(test, i) for i in range(4)]

# save each binary SVM to its own file
for i, classifier in enumerate(mc.classifiers):
    classifier.save("foo.model.%d" % i)

# load the binary SVMs back from disk
classifiers = []
for i in range(4):
    classifiers.append(loadSVM("foo.model.%d" % i))

# rebuild the composite: classify() only needs labels.numClasses and the sub-classifiers
mcnew = OneAgainstRestFixed(SVM())
mcnew.labels = misc.Container()
mcnew.labels.addAttributes(test.labels, ['numClasses', 'classLabels'])
mcnew.classifiers = classifiers
print [mcnew.classify(test, i) for i in range(4)]
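
The root cause is a general Python pitfall: a wrapper method accepts **args but never forwards them to the objects it delegates to. Below is a minimal, PyML-free sketch of that pitfall. BinaryModel, BrokenComposite, FixedComposite and try_save are hypothetical names, and the toy default saveSpace=True is only an assumption standing in for PyML's real behavior; the ValueError merely imitates the one quoted in the question. The sketch runs under both Python 2 and Python 3.

```python
# Hypothetical stand-ins for PyML classes; names and behavior are assumptions
# chosen only to mirror the saveSpace problem described above.


class BinaryModel(object):
    """Toy binary classifier that can only be saved if trained with saveSpace=False."""

    def __init__(self):
        self.save_space = None

    def train(self, data, saveSpace=True):
        # Remember how the model was trained; a real SVM would also fit weights here.
        self.save_space = saveSpace

    def save(self, filename):
        if self.save_space is not False:
            raise ValueError("in order to save you need to train as: "
                             "s.train(data, saveSpace = False)")
        return filename  # pretend the model was written to `filename`


class BrokenComposite(object):
    """Drops keyword arguments, like the multi.py line discussed above."""

    def __init__(self, num_classes):
        self.classifiers = [BinaryModel() for _ in range(num_classes)]

    def train(self, data, **args):
        for clf in self.classifiers:
            clf.train(data)  # **args is silently discarded here


class FixedComposite(BrokenComposite):
    """Forwards every keyword argument to each binary classifier."""

    def train(self, data, **args):
        for clf in self.classifiers:
            clf.train(data, **args)


def try_save(composite):
    """Return True if every sub-classifier could be saved, False otherwise."""
    try:
        for i, clf in enumerate(composite.classifiers):
            clf.save("model.%d" % i)
    except ValueError:
        return False
    return True


broken = BrokenComposite(4)
broken.train(data=None, saveSpace=False)

fixed = FixedComposite(4)
fixed.train(data=None, saveSpace=False)
```

Here try_save(broken) returns False because every sub-model silently fell back to the default saveSpace, while try_save(fixed) returns True. Forwarding **args to each sub-classifier, as OneAgainstRestFixed does above, is the entire fix.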

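Get a newer version of PyML. Starting with version 0.7.4, you can save and load OneAgainstRest classifiers with .save() and .load(); before that version, saving and loading classifiers was cumbersome and error-prone.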
