在mock.assert_has_calls中比较生成器对象

4 投票
1 回答
3980 浏览
提问于 2025-04-28 04:27

我正在为一个函数写单元测试,目的是确认这个函数内部调用的另一个函数是否使用了正确的参数。问题是,其中一个参数是一个生成器。

有没有办法用 assert_has_calls 来比较传给 fn 的生成器的内容?我想要的效果可以在 'AssertSequenceEqual' 中看到。现在,test_use_fn 失败了,因为它比较的生成器对象是不同的。

import mock

def fn(entries):
    pass

def use_fn(entries, convert=lambda x: x):
    entries = (convert(entry) for entry in entries)
    entries = fn(entries)
    entries = fn(entries)

@mock.patch('fn')
def test_use_fn(self, mock_fn):
    mock_fn.return_value = 'baz'
    entries = ['foo', 'bar']
    use_fn(entries)
    call_1 = mock.call((entry for entry in entries))
    call_2 = mock.call('baz')
    mock_fn.assert_has_calls([call_1, call_2])
暂无标签

1 个回答

3

你可以使用 call_args_list 来查看函数被调用时的参数列表,具体可以参考这个链接:https://docs.python.org/3/library/unittest.mock.html#unittest.mock.Mock.call_args_list

我假设你会检查生成器中的每个项目是否相同。我写了一个 assertEqualGenerators() 方法,放在测试用例类里来完成这个工作(如果参数不是生成器,就用标准的 assertEqual())。这个模块的文件名是 mock_generators.py,你需要用 mock_generators.fn 来替换 fn。最后一个小技巧是关于 call 对象的参数:你可以查看这个链接了解如何获取数据的细节:https://docs.python.org/3/library/unittest.mock.html#unittest.mock.call.call_list(在你的情况下,第一个元素就是你需要的)。

import unittest
from unittest import mock

def fn(entries):
    pass

def use_fn(entries, convert=lambda x: x):
    entries = (convert(entry) for entry in entries)
    entries = fn(entries)
    entries = fn(entries)

class MyTestCase(unittest.TestCase):

    def assertEqualGenerators(self,a,b):
        try:
            for x,y in zip(a,b):
                self.assertEqual(x, y)
        except TypeError:
            self.assertEqual(a, b)

    @mock.patch("mock_generators.fn")
    def test_use_fn(self, mock_fn):
        mock_fn.return_value = 'baz'
        entries = ['foo', 'bar']
        use_fn(entries)
        calls = [mock.call((entry for entry in entries)),
                    mock.call('baz')]
        self.assertEqual(len(calls), mock_fn.call_count)
        for a,b in zip(mock_fn.call_args_list,calls):
            self.assertEqualGenerators(a[0],b[0])


if __name__ == '__main__':
    unittest.main()

撰写回答