如何在Python中编写自定义的`.assertFoo()`方法?

48 投票
3 回答
13950 浏览
提问于 2025-04-16 21:17

我正在用Python的unittest写一些测试用例。现在我需要比较两个对象列表,看看第一个列表中的对象是否符合我的预期。

我该如何写一个自定义的.assertFoo()方法呢?这个方法应该做些什么?如果比较失败,它应该抛出异常吗?如果是的话,抛出什么类型的异常比较好?怎么传递错误信息?错误信息是用unicode字符串好,还是用字节字符串好呢?

可惜的是,官方文档并没有解释如何写自定义的断言方法。

如果你需要一个实际的例子,继续往下看。


我写的代码大致是这样的:

def assert_object_list(self, objs, expected):
    for index, (item, values) in enumerate(zip(objs, expected)):
        self.assertEqual(
            item.foo, values[0],
            'Item {0}: {1} != {2}'.format(index, item.foo, values[0])
        )
        self.assertEqual(
            item.bar, values[1],
            'Item {0}: {1} != {2}'.format(index, item.bar, values[1])
        )

def test_foobar(self):
    objs = [...]  # Some processing here
    expected = [
        # Expected values for ".foo" and ".bar" for each object
        (1, 'something'),
        (2, 'nothing'),
    ]
    self.assert_object_list(objs, expected)

这种方法让我们可以非常简洁地描述每个对象的预期值,而且不需要真正创建完整的对象。

但是……当一个对象的比较失败时,后面的对象就不会再比较了,这让调试变得有点困难。我想写一个自定义的方法,可以无条件地比较所有对象,然后显示所有失败的对象,而不仅仅是第一个。

3 个回答

3

这里有一个例子,用来总结如何使用numpy进行比较的单元测试。

import numpy as np
class CustomTestCase(unittest.TestCase):
    def npAssertAlmostEqual(self, first, second, rtol=1e-06, atol=1e-08):
        np.testing.assert_allclose(first, second, rtol=rtol, atol=atol)


class TestVector(CustomTestCase):
    def testFunction(self):
        vx = np.random.rand(size)
        vy = np.random.rand(size)
        self.npAssertAlmostEqual(vx, vy)
19

你应该创建一个自己的测试类,这个类要从unittest.TestCase继承。然后把你自定义的断言方法放到这个测试类里。如果你的测试失败了,就抛出一个AssertionError(断言错误)。这个错误信息应该是一个字符串。如果你想测试列表中的所有对象,而不是在遇到失败时就停止,那么可以收集所有失败的索引,等遍历完所有对象后,再构建一个总结你发现的错误信息的断言。

46

我在这些情况下使用多重继承。举个例子:

首先,我定义一个包含一些方法的类,这些方法会被其他类使用。

import os

class CustomAssertions:
    def assertFileExists(self, path):
        if not os.path.lexists(path):
            raise AssertionError('File not exists in path "' + path + '".')

现在,我定义一个类,它同时继承自unittest.TestCase和CustomAssertion这两个类。

import unittest

class MyTest(unittest.TestCase, CustomAssertions):
    def test_file_exists(self):
        self.assertFileExists('any/file/path')

if __name__ == '__main__':
    unittest.main()

撰写回答