Why can I pickle instance methods that are Theano functions, but not normal instance methods?

7 votes
1 answer
1036 views
Asked 2025-04-18 17:40

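While using joblib to parallelize some model-fitting code involving Theano functions, I came across some behaviour that seems strange to me.

Consider this very simple example: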

from joblib import Parallel, delayed
import theano
from theano import tensor as te
import numpy as np

class TheanoModel(object):
    def __init__(self):
        X = te.dvector('X')
        Y = (X ** te.log(X ** 2)).sum()
        self.theano_get_Y = theano.function([X], Y)

    def get_Y(self, x):
        return self.theano_get_Y(x)

def run(niter=100):
    x = np.random.randn(1000)
    model = TheanoModel()
    pool = Parallel(n_jobs=-1, verbose=1, pre_dispatch='all')

    # this fails with `TypeError: can't pickle instancemethod objects`...
    results = pool(delayed(model.get_Y)(x) for _ in xrange(niter))

    # # ... but this works! Why?
    # results = pool(delayed(model.theano_get_Y)(x) for _ in xrange(niter))

if __name__ == '__main__':
    run()

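I understand why the first case fails, since .get_Y() is clearly an instance method of TheanoModel. What I don't understand is why the second case works, since X, Y and theano_get_Y() are only declared inside the __init__() method of TheanoModel. theano_get_Y() can't be evaluated until the TheanoModel instance has been created. Surely, then, it should also be considered an instance method, and should therefore be unpicklable? In fact, it still works even if I explicitly declare X and Y to be attributes of the TheanoModel instance.

Can anyone explain what's going on here?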


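Update

To illustrate why I find this behaviour particularly strange, here are a few other examples of callable member objects that don't take self as their first argument: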

from joblib import Parallel, delayed
import theano
from theano import tensor as te
import numpy as np

class TheanoModel(object):
    def __init__(self):
        X = te.dvector('X')
        Y = (X ** te.log(X ** 2)).sum()
        self.theano_get_Y = theano.function([X], Y)
        def square(x):
            return x ** 2
        self.member_function = square
        self.static_method = staticmethod(square)
        self.lambda_function = lambda x: x ** 2

def run(niter=100):
    x = np.random.randn(1000)
    model = TheanoModel()
    pool = Parallel(n_jobs=-1, verbose=1, pre_dispatch='all')

    # # not allowed: `TypeError: can't pickle function objects`
    # results = pool(delayed(model.member_function)(x) for _ in xrange(niter))

    # # not allowed: `TypeError: can't pickle function objects`
    # results = pool(delayed(model.lambda_function)(x) for _ in xrange(niter))

    # # also not allowed: `TypeError: can't pickle staticmethod objects`
    # results = pool(delayed(model.static_method)(x) for _ in xrange(niter))

    # but this is totally fine!?
    results = pool(delayed(model.theano_get_Y)(x) for _ in xrange(niter))

if __name__ == '__main__':
    run()

With the exception of the theano.function, none of them can be pickled!

1 Answer

4

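Theano functions aren't ordinary Python functions. They are Python objects that override __call__. This means you can call them just like a function, but internally they are instances of a custom class. Pickle therefore treats them like any other object: it stores a reference to the class plus the instance's state, instead of looking the callable up by name as it does for plain functions. That is why you can pickle them.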
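The same distinction can be reproduced without Theano. The sketch below is only an illustration: CallableSquare, Model and can_pickle are hypothetical names, and CallableSquare is a stand-in for a compiled Theano function, not Theano's actual class. It is written for Python 3, where the exact exception type and message differ from the Python 2 errors quoted in the question.

```python
import pickle


class CallableSquare(object):
    """Hypothetical stand-in for a compiled Theano function: an instance of a
    module-level class that implements __call__."""

    def __init__(self, power):
        self.power = power

    def __call__(self, x):
        return x ** self.power


class Model(object):
    def __init__(self):
        # Instance of an importable class: pickled as a class reference plus state.
        self.callable_object = CallableSquare(2)
        # Plain function object: pickle stores functions by qualified name only,
        # and a lambda created inside __init__ cannot be looked up by name.
        self.lambda_function = lambda x: x ** 2


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


model = Model()

# The callable object survives a round trip and can still be called afterwards.
restored = pickle.loads(pickle.dumps(model.callable_object))
print(restored(3))

print(can_pickle(model.callable_object))
print(can_pickle(model.lambda_function))
```

Both attributes are called in exactly the same way, but only the one that is an instance of a named, importable class can be sent to a worker process by pickling.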
