使用joblib缓存类中部分方法的正确方法

9 投票
1 回答
1745 浏览
提问于 2025-04-17 21:09

我正在写一个类,这个类里面有一些计算量大的方法,还有一些参数是用户想要反复调整的,这些参数和计算没有关系。

这个类主要是用来做可视化的,不过我这里举个简单的例子:

class MyClass(object):

    def __init__(self, x, name, mem=None):

        self.x = x
        self.name = name
        if mem is not None:
            self.square = mem.cache(self.square)

    def square(self, x):
        """This is the 'computation heavy' method."""
        return x ** 2

    def report(self):
        """Use the results of the computation and a tweakable parameter."""
        print "Here you go, %s" % self.name
        return self.square(self.x)

基本的想法是,用户可能想用相同的 x 值,但不同的 name 参数来创建很多这个类的实例。我想让用户可以提供一个 joblib.Memory 对象,这样可以缓存计算的部分,这样他们就可以在不重新计算平方数组的情况下,给很多不同的名字“报告”结果。

(我知道这听起来有点奇怪。用户需要为每个名字创建不同的类实例,是因为他们实际上会和一个看起来像这样的接口函数进行交互。

def myfunc(x, name, mem=None):
    theclass = MyClass(x, name, mem)
    theclass.report()

但我们暂时先不讨论这个)。


根据 joblib 的文档,我用 self.square = mem.cache(self.square) 这一行来缓存 square 函数。问题是,因为 self 对于不同的实例是不同的,即使参数相同,数组每次还是会被重新计算。

我猜正确的处理方式是把这一行改成

self.square = mem.cache(self.square, ignore=["self"])

不过,这种方法有没有什么缺点呢?有没有更好的方法来实现缓存呢?

1 个回答

1

来自文档

如果你想在一个类里面使用缓存,推荐的做法是缓存一个纯函数,然后在你的类中使用这个缓存的函数。

因为你希望内存缓存是可选的,我建议可以这样做:

def square_function(x):
    """This is the 'computation heavy' method."""
    print '    square_function is executing, not using cached result.'
    return x ** 2

class MyClass(object):

    def __init__(self, x, name, mem=None):
        self.x = x
        self.name = name
        if mem is not None:
            self._square_function = mem.cache(square_function)
        else:
            self._square_function = square_function

    def square(self, x):
        return self._square_function(x)

    def report(self):
        print "Here you go, %s" % self.name
        return self.square(self.x)


from tempfile import mkdtemp
cachedir = mkdtemp()

from joblib import Memory
memory = Memory(cachedir=cachedir, verbose=0)

objects = [
    MyClass(42, 'Alice (cache)', memory),
    MyClass(42, 'Bob (cache)', memory),
    MyClass(42, 'Charlie (no cache)')
]

for obj in objects:
    print obj.report()

执行结果是:

Here you go, Alice (cache)
    square_function is executing, not using cached result.
1764
Here you go, Bob (cache)
1764
Here you go, Charlie (no cache)
    square_function is executing, not using cached result.
1764

撰写回答