ContextManager使用b访问调用的locals()

2024-04-19 10:15:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试用上下文管理器编写一个多线程助手。其思想是在一个块中定义一组函数,上下文管理器“神奇地”处理调度和所有事情。简化的工作版本如下所示:

import contextlib

@contextlib.contextmanager
def multi_threaded(count):
    funcs = []
    yield funcs
    my_slice = int(count / len(funcs))
    for i, func in enumerate(funcs):
        start = my_slice * i
        func(start, start + my_slice)   


def spawn_many():
    dataset = [1, 2, 3, 4, 5]
    with multi_threaded(len(dataset)) as mt:
        def foo(start_idx, end):
            print("foo" + str(dataset[start_idx : end]))
        def bar(start_idx, end):
            print("bar" + str(dataset[start_idx : end]))
        mt.append(foo)
        mt.append(bar)

spawn_many()

这个例子是可行的,但我想去掉这些行:

        mt.append(foo)
        mt.append(bar)

这样用户只需要定义函数,而不需要将它们添加到集合中。为什么?因为它不太容易出错,而且我无法控制用这个库编写的代码。你知道吗

问题是,在产生之后,我不在发生def foo的范围内,所以我不知道该范围内存在的locals(),这基本上就是我需要知道在那里定义了哪些函数。有什么想法/窍门/鼓励的话吗?你知道吗

感谢阅读!你知道吗


Tags: 函数管理器定义foomydefbarslice
2条回答

我读到this is not possible,至少不是没有丑陋的黑客,但我认为我的解决方案最终不是那么丑陋:

在创建时将locals()字典传递到contextmanager,contextmanager在屈服后询问该字典,以收集任何可调用项:

@contextlib.contextmanager
def multi_threaded(block_locals, count):
    yield

    funcs = [fn for fn in block_locals.values() if callable(fn)]

    my_slice = int(count / len(funcs))
    for i, func in enumerate(funcs):
        start = my_slice * i
        func(start, start + my_slice)   

def spawn_many():
    dataset = [1, 2, 3, 4, 5]
    with multi_threaded(locals(), len(dataset)):
        def foo(start_idx, end):
            print("foo" + str(dataset[start_idx : end]))
        def bar(start_idx, end):
            print("bar" + str(dataset[start_idx : end]))

        # Re-sync locals-dict handed earlier to multi_threaded().
        locals()

spawn_many()

注意,这个技巧之所以有效,是因为最后一次调用块中的locals()。Python似乎只在调用locals()时才同步locals()-dictionary<;gt;函数局部变量。如果没有最后一个调用,multi_threaded会将{'dataset': [1, 2, 3, 4, 5]}视为局部变量。你知道吗

装潢师可能会好一点:

import contextlib

@contextlib.contextmanager
def multi_threaded(count):
    funcs = []
    yield funcs
    my_slice = int(count / len(funcs))
    for i, func in enumerate(funcs):
        start = my_slice * i
        func(start, start + my_slice)   

def add_to_flist(mt):
    def _add_to_flist(func):
        mt.append(func)
        return func
    return _add_to_flist

def spawn_many():
    dataset = [1, 2, 3, 4, 5]
    with multi_threaded(len(dataset)) as mt:
        @add_to_flist(mt)
        def foo(start_idx, end):
            print("foo" + str(dataset[start_idx : end]))
        @add_to_flist(mt)
        def bar(start_idx, end):
            print("bar" + str(dataset[start_idx : end]))

spawn_many()

相关问题 更多 >