如何在Python中编写有效的类装饰器?

6 投票
3 回答
806 浏览
提问于 2025-04-17 06:06

我刚写了一个类装饰器,像下面这样,试图为目标类中的每个方法添加调试支持:

import unittest
import inspect

def Debug(targetCls):
   for name, func in inspect.getmembers(targetCls, inspect.ismethod):
      def wrapper(*args, **kwargs):
         print ("Start debug support for %s.%s()" % (targetCls.__name__, name));
         result = func(*args, **kwargs)
         return result
      setattr(targetCls, name, wrapper)
   return targetCls

@Debug
class MyTestClass:
   def TestMethod1(self):
      print 'TestMethod1'

   def TestMethod2(self):
      print 'TestMethod2'

class Test(unittest.TestCase):

   def testName(self):
      for name, func in inspect.getmembers(MyTestClass, inspect.ismethod):
         print name, func

      print '~~~~~~~~~~~~~~~~~~~~~~~~~~'
      testCls = MyTestClass()

      testCls.TestMethod1()
      testCls.TestMethod2()


if __name__ == "__main__":
   #import sys;sys.argv = ['', 'Test.testName']
   unittest.main()

运行上面的代码,结果是:

Finding files... done.
Importing test modules ... done.

TestMethod1 <unbound method MyTestClass.wrapper>
TestMethod2 <unbound method MyTestClass.wrapper>
~~~~~~~~~~~~~~~~~~~~~~~~~~
Start debug support for MyTestClass.TestMethod2()
TestMethod2
Start debug support for MyTestClass.TestMethod2()
TestMethod2
----------------------------------------------------------------------
Ran 1 test in 0.004s

OK

你会发现'TestMethod2'被打印了两次。

这有什么问题吗?我对Python中的装饰器理解正确吗?

有没有什么解决办法?顺便说一下,我不想给类中的每个方法都添加装饰器。

3 个回答

0

问题不在于如何写一个有效的类装饰器;显然这个类是被装饰的,而且它不会只是抛出异常,你可以顺利地添加你想要加到类里的代码。所以很明显,你需要找的是装饰器里的错误,而不是在纠结自己是否写了一个有效的装饰器。

在这个情况下,问题出在闭包上。在你的 Debug 装饰器里,你对 namefunc 进行了循环,每次循环时你都会定义一个叫 wrapper 的函数,这个函数是一个闭包,它可以访问循环变量。问题是,一旦开始下一次循环,循环变量所指的内容就已经改变了。但你只会在整个循环结束后才调用这些 wrapper 函数。所以每个被装饰的方法最终调用的都是循环的最后一个值:在这个例子中,就是 TestMethod2

在这种情况下,我会做一个方法级别的装饰器,但因为你不想每个方法都显式地加装饰器,所以你可以做一个类装饰器,它会遍历所有的方法,并把它们传递给方法装饰器。这是可行的,因为你并没有通过闭包让 wrapper 访问你的循环变量;相反,你是把循环变量所指的内容的引用传递给一个函数(这个函数是装饰器,它构造并返回一个 wrapper);一旦这样做了,下一次循环迭代重新绑定循环变量就不会影响到这个 wrapper 函数。

0

这是一个很常见的问题。你可能认为 wrapper 是一个闭包,它能捕捉到当前的 func 参数,但实际上并不是这样。如果你没有把当前的 func 值传给 wrapper,那么它的值只会在循环结束后被查找,所以你得到的只是最后一个值。

你可以这样做:

def Debug(targetCls):

   def wrap(name,func): # use the current func
      def wrapper(*args, **kwargs):
         print ("Start debug support for %s.%s()" % (targetCls.__name__, name));
         result = func(*args, **kwargs)
         return result
      return wrapper

   for name, func in inspect.getmembers(targetCls, inspect.ismethod):
      setattr(targetCls, name, wrap(name, func))
   return targetCls
3

考虑这个循环:

for name, func in inspect.getmembers(targetCls, inspect.ismethod):
        def wrapper(*args, **kwargs):
            print ("Start debug support for %s.%s()" % (targetCls.__name__, name))

wrapper 最终被调用时,它会查找 name 的值。如果在本地找不到,它会在 for-loop 的扩展范围内查找,并且找到了。但是到那时,for-loop 已经结束,name 变成了循环中的最后一个值,也就是 TestMethod2

所以每次调用 wrapper 时,name 的值都是 TestMethod2

解决这个问题的方法是创建一个扩展的范围,让 name 绑定到正确的值。这可以通过一个函数 closure 来实现,使用默认参数值。默认参数值在定义时就被评估并固定,并且绑定到同名的变量上。

def Debug(targetCls):
    for name, func in inspect.getmembers(targetCls, inspect.ismethod):
        def closure(name=name,func=func):
            def wrapper(*args, **kwargs):
                print ("Start debug support for %s.%s()" % (targetCls.__name__, name))
                result = func(*args, **kwargs)
                return result
            return wrapper        
        setattr(targetCls, name, closure())
    return targetCls

在评论中,eryksun 提出了一个更好的解决方案:

def Debug(targetCls):
    def closure(name,func):
        def wrapper(*args, **kwargs):
            print ("Start debug support for %s.%s()" % (targetCls.__name__, name));
            result = func(*args, **kwargs)
            return result
        return wrapper        
    for name, func in inspect.getmembers(targetCls, inspect.ismethod):
        setattr(targetCls, name, closure(name,func))
    return targetCls

现在 closure 只需要解析一次。每次调用 closure(name,func) 时,都会创建一个自己的函数范围,并且 namefunc 的值会被正确绑定。

撰写回答