Python 类装饰器参数

39 投票
9 回答
37045 浏览
提问于 2025-04-17 02:43

我正在尝试在Python中给我的类装饰器传递可选参数。

class Cache(object):
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.


@Cache
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2

我在代码中设置的第二个装饰器,想用参数来覆盖默认的设置(在我的__init__函数中,默认是max_hits=10, timeout=5),但是它没有正常工作,我遇到了一个错误,提示TypeError: __init__() takes at least 2 arguments (3 given)。我尝试了很多解决办法,也看了相关的文章,但还是无法解决这个问题。

有没有什么想法可以帮我解决这个问题?谢谢!

9 个回答

11

我更倾向于把包装器放在类的 __call__ 方法里面:

更新:这个方法在 Python 3.6 中已经测试过了,所以我不确定在更高版本或更早版本中是否适用。

class Cache:
    def __init__(self, max_hits=10, timeout=5):
        # Remove function from here and add it to the __call__
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, function):
        def wrapper(*args):
            value = function(*args)
            # saving to cache codes
            return value
        return wrapper

@Cache()
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2
30
@Cache
def double(...): 
   ...

等同于

def double(...):
   ...
double=Cache(double)

@Cache(max_hits=100, timeout=50)
def double(...):
   ...

等同于

def double(...):
    ...
double = Cache(max_hits=100, timeout=50)(double)

Cache(max_hits=100, timeout=50)(double)Cache(double) 的含义差别很大。

试图让 Cache 同时处理这两种情况是不明智的。

你可以使用一个装饰器工厂,它可以接受可选的 max_hitstimeout 参数,并返回一个装饰器:

class Cache(object):
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.

def cache_hits(max_hits=10, timeout=5):
    def _cache(function):
        return Cache(function,max_hits,timeout)
    return _cache

@cache_hits()
def double(x):
    return x * 2

@cache_hits(max_hits=100, timeout=50)
def double(x):
    return x * 2

附言:如果 Cache 类除了 __init____call__ 方法外没有其他方法,你可以考虑把所有代码放到 _cache 函数里,完全去掉 Cache 类。

32

@Cache(max_hits=100, timeout=50) 这个语句实际上是在调用 __init__(max_hits=100, timeout=50),所以你没有满足 function 这个参数的要求。

你可以通过一个包装方法来实现你的装饰器,这个方法可以检查是否有函数存在。如果找到了一个函数,它就可以返回缓存对象。否则,它可以返回一个将作为装饰器使用的包装函数。

class _Cache(object):
    def __init__(self, function, max_hits=10, timeout=5):
        self.function = function
        self.max_hits = max_hits
        self.timeout = timeout
        self.cache = {}

    def __call__(self, *args):
        # Here the code returning the correct thing.

# wrap _Cache to allow for deferred calling
def Cache(function=None, max_hits=10, timeout=5):
    if function:
        return _Cache(function)
    else:
        def wrapper(function):
            return _Cache(function, max_hits, timeout)

        return wrapper

@Cache
def double(x):
    return x * 2

@Cache(max_hits=100, timeout=50)
def double(x):
    return x * 2

撰写回答