编写一个类装饰器,使所有方法都应用装饰器
我正在尝试写一个类装饰器,这个装饰器可以对类里面的所有方法都应用另一个装饰器:
import inspect
def decorate_func(func):
def wrapper(*args, **kwargs):
print "before"
ret = func(*args, **kwargs)
print "after"
return ret
for attr in "__module__", "__name__", "__doc__":
setattr(wrapper, attr, getattr(func, attr))
return wrapper
def decorate_class(cls):
for name, meth in inspect.getmembers(cls, inspect.ismethod):
setattr(cls, name, decorate_func(meth))
return cls
@decorate_class
class MyClass(object):
def __init__(self):
self.a = 10
print "__init__"
def foo(self):
print self.a
@staticmethod
def baz():
print "baz"
@classmethod
def bar(cls):
print "bar"
obj = MyClass()
obj.foo()
obj.baz()
MyClass.baz()
obj.bar()
MyClass.bar()
这个方法差不多可以用,但@classmethod
需要特别处理:
$ python test.py
before
__init__
after
before
10
after
baz
baz
before
Traceback (most recent call last):
File "test.py", line 44, in <module>
obj.bar()
File "test.py", line 7, in wrapper
ret = func(*args, **kwargs)
TypeError: bar() takes exactly 1 argument (2 given)
有没有什么好的办法来解决这个问题呢?我查看了被@classmethod
装饰的方法,但我没有找到什么能把它们和其他“类型”的方法区分开的东西。
更新
这里是完整的解决方案,记录一下(使用描述符来很好地处理@staticmethod
和@classmethod
,还有aix的技巧来区分@classmethod
和普通方法):
import inspect
class DecoratedMethod(object):
def __init__(self, func):
self.func = func
def __get__(self, obj, cls=None):
def wrapper(*args, **kwargs):
print "before"
ret = self.func(obj, *args, **kwargs)
print "after"
return ret
for attr in "__module__", "__name__", "__doc__":
setattr(wrapper, attr, getattr(self.func, attr))
return wrapper
class DecoratedClassMethod(object):
def __init__(self, func):
self.func = func
def __get__(self, obj, cls=None):
def wrapper(*args, **kwargs):
print "before"
ret = self.func(*args, **kwargs)
print "after"
return ret
for attr in "__module__", "__name__", "__doc__":
setattr(wrapper, attr, getattr(self.func, attr))
return wrapper
def decorate_class(cls):
for name, meth in inspect.getmembers(cls):
if inspect.ismethod(meth):
if inspect.isclass(meth.im_self):
# meth is a classmethod
setattr(cls, name, DecoratedClassMethod(meth))
else:
# meth is a regular method
setattr(cls, name, DecoratedMethod(meth))
elif inspect.isfunction(meth):
# meth is a staticmethod
setattr(cls, name, DecoratedClassMethod(meth))
return cls
@decorate_class
class MyClass(object):
def __init__(self):
self.a = 10
print "__init__"
def foo(self):
print self.a
@staticmethod
def baz():
print "baz"
@classmethod
def bar(cls):
print "bar"
obj = MyClass()
obj.foo()
obj.baz()
MyClass.baz()
obj.bar()
MyClass.bar()
3 个回答
0
上面的回答不直接适用于python3。根据其他一些很棒的回答,我想出了以下解决方案:
import inspect
import types
import networkx as nx
def override_methods(cls):
for name, meth in inspect.getmembers(cls):
if name in cls.methods_to_override:
setattr(cls, name, cls.DecorateMethod(meth))
return cls
@override_methods
class DiGraph(nx.DiGraph):
methods_to_override = ("add_node", "remove_edge", "add_edge")
class DecorateMethod:
def __init__(self, func):
self.func = func
def __get__(self, obj, cls=None):
def wrapper(*args, **kwargs):
ret = self.func(obj, *args, **kwargs)
obj._dirty = True # This is the attribute I want to update
return ret
return wrapper
def __init__(self):
super().__init__()
self._dirty = True
现在,每当调用元组 methods_to_override
中的方法时,都会设置一个“脏标志”。当然,其他任何东西也可以放在这里。并不一定要在需要重写方法的类中包含 DecorateMethod
类。不过,由于 DecorateMethod
使用了特定的类属性,我更倾向于将其作为类属性来处理。
1
(评论太长了)
我擅自为你的解决方案增加了一个功能,可以指定哪些方法需要被装饰:
def class_decorator(*method_names):
def wrapper(cls):
for name, meth in inspect.getmembers(cls):
if name in method_names or len(method_names) == 0:
if inspect.ismethod(meth):
if inspect.isclass(meth.im_self):
# meth is a classmethod
setattr(cls, name, VerifyTokenMethod(meth))
else:
# meth is a regular method
setattr(cls, name, VerifyTokenMethod(meth))
elif inspect.isfunction(meth):
# meth is a staticmethod
setattr(cls, name, VerifyTokenMethod(meth))
return cls
return wrapper
使用方法:
@class_decorator('some_method')
class Foo(object):
def some_method(self):
print 'I am decorated'
def another_method(self):
print 'I am NOT decorated'
11
你可以用 inspect.isclass(meth.im_self)
这个代码来判断 meth
是否是一个类方法。
def decorate_class(cls):
for name, meth in inspect.getmembers(cls, inspect.ismethod):
if inspect.isclass(meth.im_self):
print '%s is a class method' % name
# TODO
...
return cls