如何让'import x'返回types.ModuleType的子类?

7 投票
3 回答
1546 浏览
提问于 2025-04-16 17:31

Python的import语句能不能通过导入钩子返回一个types.ModuleType的子类呢?我想重写__getattribute__,这样当某个模块外的代码引用不在__all__里的名字时,就能显示一个运行时警告。

我知道怎么在sys.modules['foo']被导入后替换它。我想要的是在模块被导入时,能够对符合某种模式的模块进行处理,这样导入的代码就有机会触发警告。

Python并不太赞成把东西分得太清楚,比如公有和私有。这个想法并不是为了阻止你的模块用户输入from somemodule import sys;而是作为一个文档工具。这种处理方式应该能让你更容易地记录模块的API,通过包含正确的__all__。它应该能帮助你避免错误地把sys引用为somemodule.sys,而是简单地使用import sys

3 个回答

3

其实你根本不需要什么导入钩子。简单来说,只要在Python的模块注册表 sys.modules 中用你想要的名字放一个对象的引用,以后在这个Python会话中(即使是在其他模块里)任何 import 语句都会导入这个对象的引用。

import types, sys

class MyModuleType(types.ModuleType):
    pass

sys.modules["foo"] = MyModuleType("foo")

import foo

print type(foo)   # MyModuleType

Python根本不在乎 sys.modules 中的对象到底是模块还是其他类型的对象。你甚至可以把一个 int 放进去,Python也不会在意。

sys.modules["answer"] = 42
import answer
3

你可以参考这个ActiveState的例子,可能可以做成这样:

# safemodule.py
import sys
import types
import warnings

class EncapsulationWarning(RuntimeWarning): pass

class ModuleWrapper(types.ModuleType):
    def __init__(self, context):
        self.context = context
        super(ModuleWrapper, self).__init__(
                context.__name__,
                context.__doc__)

    def __getattribute__(self, key):
        context = object.__getattribute__(self, 'context')
        if hasattr(context, '__all__') and key not in context.__all__:
            warnings.warn('%s not in %s.__all__' % (key, context.__name__),
                          EncapsulationWarning,
                          2)
        return context.__getattribute__(key)

if 'old_import' not in globals():
    old_import = __import__

    def safe_import(*args, **kwargs):
        m = old_import(*args, **kwargs)
        return ModuleWrapper(m)

    __builtins__['__import__'] = safe_import

然后,你可以像这样使用它:

C:\temp>python
ActivePython 2.5.2.2 (ActiveState Software Inc.) based on
Python 2.5.2 (r252:60911, Mar 27 2008, 17:57:18) [MSC v.1310 32 bit (Intel)] on
win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import safemodule
>>> import sys
>>> type(sys)
<class 'safemodule.ModuleWrapper'>
>>>

当然,你也可以根据需要调整,只包装某些模块等等。

5

因为我之前没理解你的问题,所以我想再试一次(同时保留我之前的回答以供参考)。

这里有一个替代方案,它不需要导入钩子。这个方法可以很方便地在每个模块中使用:包含这段代码的模块会有特别的 __getattribute__() 行为,而其他模块则会像往常一样工作。

class StrictModule(types.ModuleType):

    def __getattribute__(self, key):
        if key is "__dict__":  # fake __dict__ with only visible attributes
            return dict((k, v) for k, v in globals().iteritems()
                        if k.startswith("__") or k in __all__)
        if (key.startswith("__") or key in __all__) and key in globals():
            return globals()[key]
        else:
            raise AttributeError("'module' object has no attribute '%s'"  % key)

    def __setattr__(self, key, value):
        globals()[key] = value

sys.modules[__name__] = StrictModule(__name__)

请记住,这种“限制”其实很容易绕过,只需调用普通的 module 类型的 __getattribute__()(或者说 object 类型的)。我感觉你是想为你的模块提供某种“私有成员”的限制,但其实这样做几乎没有什么意义。

撰写回答