如何在Python中发现特定包中的类?
我有一组插件风格的模块,长得像这样:
/Plugins /Plugins/__init__.py /Plugins/Plugin1.py /Plugins/Plugin2.py etc...
每个 .py 文件里都有一个类,这个类是从 PluginBaseClass
这个基类派生出来的。所以我需要列出 Plugins
这个包里的每个模块,然后找出所有实现了 PluginBaseClass
的类。理想情况下,我想做到类似这样的事情:
for klass in iter_plugins(project.Plugins):
action = klass()
action.run()
我看到过一些其他的回答,但我的情况有点不同。我实际上是导入了基础包(也就是 import project.Plugins
),然后我需要在发现模块之后找到这些类。
4 个回答
6
你可能应该在 __init__.py
文件中定义一个叫 __all__
的列表,这个列表里包含你包里的子模块。这样做是为了支持别人使用 from Plugins import *
这种方式来导入模块。如果你这样做了,就可以用下面的方式来遍历这些模块:
import Plugins
import sys
modules = { }
for module in Plugins.__all__:
__import__( module )
modules[ module ] = sys.modules[ module ]
# iterate over dir( module ) as above
这里有个原因,为什么其他回答会失败,那是因为 __import__
这个函数会导入最底层的模块,但它返回的是最顶层的模块(具体可以查看 文档)。我也不知道为什么会这样。
7
编辑:这是一个修订后的解决方案。我意识到在测试我之前的方案时犯了个错误,它并没有像你想的那样工作。所以这里是一个更完整的解决方案:
import os
from imp import find_module
from types import ModuleType, ClassType
def iter_plugins(package):
"""Receives package (as a string) and, for all of its contained modules,
generates all classes that are subclasses of PluginBaseClass."""
# Despite the function name, "find_module" will find the package
# (the "filename" part of the return value will be None, in this case)
filename, path, description = find_module(package)
# dir(some_package) will not list the modules within the package,
# so we explicitly look for files. If you need to recursively descend
# a directory tree, you can adapt this to use os.walk instead of os.listdir
modules = sorted(set(i.partition('.')[0]
for i in os.listdir(path)
if i.endswith(('.py', '.pyc', '.pyo'))
and not i.startswith('__init__.py')))
pkg = __import__(package, fromlist=modules)
for m in modules:
module = getattr(pkg, m)
if type(module) == ModuleType:
for c in dir(module):
klass = getattr(module, c)
if (type(klass) == ClassType and
klass is not PluginBaseClass and
issubclass(klass, PluginBaseClass)):
yield klass
我之前的解决方案是:
你可以试试这样的:
from types import ModuleType
import Plugins
classes = []
for item in dir(Plugins):
module = getattr(Plugins, item)
# Get all (and only) modules in Plugins
if type(module) == ModuleType:
for c in dir(module):
klass = getattr(module, c)
if isinstance(klass, PluginBaseClass):
classes.append(klass)
其实,更好的办法是,如果你想要一些模块化的设计:
from types import ModuleType
def iter_plugins(package):
# This assumes "package" is a package name.
# If it's the package itself, you can remove this __import__
pkg = __import__(package)
for item in dir(pkg):
module = getattr(pkg, item)
if type(module) == ModuleType:
for c in dir(module):
klass = getattr(module, c)
if issubclass(klass, PluginBaseClass):
yield klass
2
扫描模块并不是个好主意。如果你需要类的注册功能,建议你看看元类,或者使用一些现成的解决方案,比如zope.interface。通过元类实现的简单解决方案可能看起来像这样:
from functools import reduce
class DerivationRegistry(type):
def __init__(cls,name,bases,cls_dict):
type.__init__(cls,name,bases,cls_dict)
cls._subclasses = set()
for base in bases:
if isinstance(base,DerivationRegistry):
base._subclasses.add(cls)
def getSubclasses(cls):
return reduce( set.union,
( succ.getSubclasses() for succ in cls._subclasses if isinstance(succ,DerivationRegistry)),
cls._subclasses)
class Base(object):
__metaclass__ = DerivationRegistry
class Cls1(object):
pass
class Cls2(Base):
pass
class Cls3(Cls2,Cls1):
pass
class Cls4(Cls3):
pass
print(Base.getSubclasses())