如何让'import x'返回types.ModuleType的子类?
Python的import
语句能不能通过导入钩子返回一个types.ModuleType
的子类呢?我想重写__getattribute__
,这样当某个模块外的代码引用不在__all__
里的名字时,就能显示一个运行时警告。
我知道怎么在sys.modules['foo']
被导入后替换它。我想要的是在模块被导入时,能够对符合某种模式的模块进行处理,这样导入的代码就有机会触发警告。
Python并不太赞成把东西分得太清楚,比如公有和私有。这个想法并不是为了阻止你的模块用户输入from somemodule import sys
;而是作为一个文档工具。这种处理方式应该能让你更容易地记录模块的API,通过包含正确的__all__
。它应该能帮助你避免错误地把sys
引用为somemodule.sys
,而是简单地使用import sys
。
3 个回答
其实你根本不需要什么导入钩子。简单来说,只要在Python的模块注册表 sys.modules
中用你想要的名字放一个对象的引用,以后在这个Python会话中(即使是在其他模块里)任何 import
语句都会导入这个对象的引用。
import types, sys
class MyModuleType(types.ModuleType):
pass
sys.modules["foo"] = MyModuleType("foo")
import foo
print type(foo) # MyModuleType
Python根本不在乎 sys.modules
中的对象到底是模块还是其他类型的对象。你甚至可以把一个 int
放进去,Python也不会在意。
sys.modules["answer"] = 42
import answer
你可以参考这个ActiveState的例子,可能可以做成这样:
# safemodule.py
import sys
import types
import warnings
class EncapsulationWarning(RuntimeWarning): pass
class ModuleWrapper(types.ModuleType):
def __init__(self, context):
self.context = context
super(ModuleWrapper, self).__init__(
context.__name__,
context.__doc__)
def __getattribute__(self, key):
context = object.__getattribute__(self, 'context')
if hasattr(context, '__all__') and key not in context.__all__:
warnings.warn('%s not in %s.__all__' % (key, context.__name__),
EncapsulationWarning,
2)
return context.__getattribute__(key)
if 'old_import' not in globals():
old_import = __import__
def safe_import(*args, **kwargs):
m = old_import(*args, **kwargs)
return ModuleWrapper(m)
__builtins__['__import__'] = safe_import
然后,你可以像这样使用它:
C:\temp>python
ActivePython 2.5.2.2 (ActiveState Software Inc.) based on
Python 2.5.2 (r252:60911, Mar 27 2008, 17:57:18) [MSC v.1310 32 bit (Intel)] on
win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import safemodule
>>> import sys
>>> type(sys)
<class 'safemodule.ModuleWrapper'>
>>>
当然,你也可以根据需要调整,只包装某些模块等等。
因为我之前没理解你的问题,所以我想再试一次(同时保留我之前的回答以供参考)。
这里有一个替代方案,它不需要导入钩子。这个方法可以很方便地在每个模块中使用:包含这段代码的模块会有特别的 __getattribute__()
行为,而其他模块则会像往常一样工作。
class StrictModule(types.ModuleType):
def __getattribute__(self, key):
if key is "__dict__": # fake __dict__ with only visible attributes
return dict((k, v) for k, v in globals().iteritems()
if k.startswith("__") or k in __all__)
if (key.startswith("__") or key in __all__) and key in globals():
return globals()[key]
else:
raise AttributeError("'module' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
globals()[key] = value
sys.modules[__name__] = StrictModule(__name__)
请记住,这种“限制”其实很容易绕过,只需调用普通的 module
类型的 __getattribute__()
(或者说 object
类型的)。我感觉你是想为你的模块提供某种“私有成员”的限制,但其实这样做几乎没有什么意义。