如何序列化/反序列化pybrain网络?

9 投票
1 回答
3886 浏览
提问于 2025-04-16 07:56

PyBrain 是一个 Python 库,它提供了易于使用的人工神经网络功能。

我在使用 pickle 或 cPickle 来保存和读取 PyBrain 网络时遇到了问题。

下面是一个例子:

from pybrain.datasets            import SupervisedDataSet
from pybrain.tools.shortcuts     import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import cPickle as pickle
import numpy as np 

#generate some data
np.random.seed(93939393)
data = SupervisedDataSet(2, 1)
for x in xrange(10):
    y = x * 3
    z = x + y + 0.2 * np.random.randn()  
    data.addSample((x, y), (z,))

#build a network and train it    

net1 = buildNetwork( data.indim, 2, data.outdim )
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True)
for i in xrange(4):
    trainer1.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0])

这是上面代码的输出结果:

Total error: 201.501998476
    value after 0 epochs: 2.79
Total error: 152.487616382
    value after 1 epochs: 5.44
Total error: 120.48092561
    value after 2 epochs: 7.56
Total error: 97.9884043452
    value after 3 epochs: 8.41

你可以看到,随着训练的进行,网络的总错误率在下降。同时,预测的值也逐渐接近预期的值 12。

现在我们将做一个类似的练习,但这次会涉及到保存和读取:

print 'creating net2'
net2 = buildNetwork(data.indim, 2, data.outdim)
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0])

#So far, so good. Let's test pickle
pickle.dump(net2, open('testNetwork.dump', 'w'))
net2 = pickle.load(open('testNetwork.dump'))
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
print 'loaded net2 using pickle, continue training'
for i in xrange(1, 4):
        trainer2.trainEpochs(1)
        print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0])

这是这个代码块的输出结果:

creating net2
Total error: 176.339378639
    value after 1 epochs: 5.45
loaded net2 using pickle, continue training
Total error: 123.392181859
    value after 1 epochs: 5.45
Total error: 94.2867637623
    value after 2 epochs: 5.45
Total error: 78.076711114
    value after 3 epochs: 5.45

如你所见,训练似乎对网络有一些影响(报告的总错误值继续下降),但是网络的输出值却停留在了第一次训练时的一个值上。

我需要注意有什么缓存机制导致了这种错误的行为吗?有没有更好的方法来保存和读取 PyBrain 网络?

相关的版本信息:

  • Python 2.6.5 (r265:79096, 2010年3月19日, 21:48:26) [MSC v.1500 32位 (Intel)]
  • Numpy 1.5.1
  • cPickle 1.71
  • pybrain 0.3

附言:我在项目网站上创建了 一个错误报告,会在 Stack Overflow 和错误跟踪器上保持更新。

1 个回答

11

原因

造成这种情况的原因是PyBrain模块中对参数(.params)和导数(.derivs)的处理方式。实际上,所有网络参数都存储在一个数组里,但每个ModuleConnection对象可以访问“自己的”.params,不过这些其实只是总数组的一部分视图。这种设计允许在本地和整个网络中对同一数据结构进行读写操作。

显然,这种视图链接在进行序列化和反序列化时会丢失。

解决方案

在从文件加载后插入

net2.sorted = False
net2.sortModules()

(这会重新创建这种共享),这样就应该可以正常工作了。

撰写回答