如何模拟Python函数以避免在导入时被调用?
我正在为别人的代码编写单元测试(使用pytest),但我不能以任何方式修改或更改这段代码。这段代码有一个全局变量,它在任何函数外部通过一个函数的返回值进行初始化,并且这个函数在本地运行时会报错。我不能分享那段代码,但我写了一个简单的文件,遇到了同样的问题:
def annoying_function():
'''Does something that generates exception due to some hardcoded cloud stuff'''
raise ValueError() # Simulate the original function raising error due to no cloud connection
annoying_variable = annoying_function()
def normal_function():
'''Works fine by itself'''
return True
这是我的测试函数:
def test_normal_function():
from app.annoying_file import normal_function
assert normal_function() == True
这个测试失败了,因为annoying_function
抛出了ValueError
,因为它在模块导入时仍然被调用。
以下是错误堆栈信息:
failed: def test_normal_function():
> from app.annoying_file import normal_function
test\test_annoying_file.py:6:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
app\annoying_file.py:6: in <module>
annoying_variable = annoying_function()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def annoying_function():
'''Does something that generates exception due to some hardcoded cloud stuff'''
> raise ValueError()
E ValueError
app\annoying_file.py:3: ValueError
我尝试像这样模拟这个annoying_function
:
def test_normal_function(mocker):
mocker.patch("app.annoying_file.annoying_function", return_value="foo")
from app.annoying_file import normal_function
assert normal_function() == True
但结果还是一样。
以下是错误堆栈信息:
failed: thing = <module 'app' (<_frozen_importlib_external._NamespaceLoader object at 0x00000244A7C72FE0>)>
comp = 'annoying_file', import_path = 'app.annoying_file'
def _dot_lookup(thing, comp, import_path):
try:
> return getattr(thing, comp)
E AttributeError: module 'app' has no attribute 'annoying_file'
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1238: AttributeError
During handling of the above exception, another exception occurred:
mocker = <pytest_mock.plugin.MockerFixture object at 0x00000244A7C72380>
def test_normal_function(mocker):
> mocker.patch("app.annoying_file.annoying_function", return_value="foo")
test\test_annoying_file.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.venv\lib\site-packages\pytest_mock\plugin.py:440: in __call__
return self._start_patch(
.venv\lib\site-packages\pytest_mock\plugin.py:258: in _start_patch
mocked: MockType = p.start()
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1585: in start
result = self.__enter__()
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1421: in __enter__
self.target = self.getter()
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1608: in <lambda>
getter = lambda: _importer(target)
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1251: in _importer
thing = _dot_lookup(thing, comp, import_path)
..\..\..\..\.pyenv\pyenv-win\versions\3.10.5\lib\unittest\mock.py:1240: in _dot_lookup
__import__(import_path)
app\annoying_file.py:6: in <module>
annoying_variable = annoying_function()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def annoying_function():
'''Does something that generates exception due to some hardcoded cloud stuff'''
> raise ValueError()
E ValueError
app\annoying_file.py:3: ValueError
而且移动导入语句的位置也没有影响我的结果。
根据我所了解的,这种情况发生是因为模拟工具(我使用的是pytest-mock)需要导入包含要模拟的函数的文件,而在导入这个文件时,annoying_variable = annoying_function()
这一行会执行,结果导致模拟过程失败。
我找到的唯一能让这个问题部分解决的方法是模拟导致原始代码出错的云服务部分,但我想避免这样做,因为那样我的测试就不再是单元测试了。
再次强调,我不能修改或更改原始代码。任何想法或建议我都非常感激。
1 个回答
2
正如其他评论者提到的,你试图解决的问题其实反映了代码中更大的问题,所以直接给出解决这个具体问题的答案可能不是最好的办法。不过,下面有一种不太常规且有点乱的方法可以尝试。这种方法基于以下几个想法:
- 在
annoying_file.py
被导入之前,动态调整它的源代码,这样annoying_function()
就不会被调用。根据你的代码示例,我们可以通过把annoying_variable = annoying_function()
替换成annoying_variable = None
来实现这个目标,注意这是在实际源代码中进行的。 - 导入经过动态调整的模块,而不是原始模块。
- 在这个经过动态调整的模块中测试
normal_function()
。
在下面的代码中,我假设:
- 有一个名为
annoying_file.py
的模块,里面包含了你提到的annoying_function()
、annoying_variable
和normal_function()
。 annoying_file.py
和包含下面代码的模块在同一个文件夹里。
from ast import parse, unparse, Assign, Constant
from importlib.abc import SourceLoader
from importlib.util import module_from_spec, spec_from_loader
def patch_annoying_variable_in(module_name: str) -> str:
"""Return patched source code, where `annoying_variable = None`"""
with open(f"{module_name}.py", mode="r") as f:
tree = parse(f.read())
for stmt in tree.body:
# Assign None to `annoying_variable`
if (isinstance(stmt, Assign) and len(stmt.targets) == 1
and stmt.targets[0].id == "annoying_variable"):
stmt.value = Constant(value=None)
break
return unparse(tree)
def import_from(module_name: str, source_code: str):
"""Load and return a module that has the given name and holds the given code."""
# Following https://stackoverflow.com/questions/62294877/
class SourceStringLoader(SourceLoader):
def get_data(self, path): return source_code.encode("utf-8")
def get_filename(self, fullname): return f"{module_name}.py (patched)"
loader = SourceStringLoader()
mod = module_from_spec(spec_from_loader(module_name, loader))
loader.exec_module(mod)
return mod
def test_normal_function():
module_name = "annoying_file"
patched_code = patch_annoying_variable_in(module_name)
mod = import_from(module_name, patched_code)
assert mod.normal_function() == True
这段代码实现了以下功能:
- 通过
patch_annoying_variable_in()
,原始的annoying_file
代码被解析。对annoying_variable
的赋值被替换,这样annoying_function()
就不会被执行。最终得到的调整后的源代码会被返回。 - 通过
import_from()
,加载调整后的源代码作为一个模块。 - 最后,
test_normal_function()
利用前面两个函数来测试这个动态调整的模块。