如何在Python unittest框架中简洁实现多个相似的单元测试?

15 投票
9 回答
8171 浏览
提问于 2025-04-11 20:01

我正在为一组函数编写单元测试,这些函数都有一些共同的特性。比如,当用两个矩阵调用这个函数时,它会返回一个形状已知的矩阵。

我希望能写一个测试,来检查这一组函数是否都符合这个特性,而不需要为每个函数单独写一个测试(特别是因为以后可能还会添加更多的函数)。

一种方法是遍历这些函数的列表:

import unittest
import numpy

from somewhere import the_functions
from somewhere.else import TheClass

class Test_the_functions(unittest.TestCase):
  def setUp(self):
    self.matrix1 = numpy.ones((5,10))
    self.matrix2 = numpy.identity(5)

  def testOutputShape(unittest.TestCase):
     """Output of functions be of a certain shape"""
     for function in all_functions:
       output = function(self.matrix1, self.matrix2)
       fail_message = "%s produces output of the wrong shape" % str(function)
       self.assertEqual(self.matrix1.shape, output.shape, fail_message)

if __name__ == "__main__":
  unittest.main()

这个想法是从Dive Into Python上得到的。在那里,测试的不是函数列表,而是已知的输入输出对。不过,这种方法的问题在于,如果列表中的任何一个元素测试失败了,后面的元素就不会被测试了。

我考虑过通过子类化unittest.TestCase来实现,并以某种方式将要测试的具体函数作为参数传入,但据我所知,这样做会导致我们无法使用unittest.main(),因为没有办法将参数传递给测试用例。

我还尝试过使用setattr和lambda动态地将“testSomething”函数附加到测试用例上,但测试用例并没有识别它们。

我该如何重写这个代码,以便在保持测试列表易于扩展的同时,确保每个测试都能运行呢?

9 个回答

6

这里其实不需要用到 元类。一个简单的循环就可以了。看看下面的例子:

import unittest

class TestCase1(unittest.TestCase):
    def check_something(self, param1):
        self.assertTrue(param1)

def _add_test(name, param1):
    def test_method(self):
        self.check_something(param1)
    setattr(TestCase1, 'test_' + name, test_method)
    test_method.__name__ = 'test_' + name
    
for i in range(0, 3):
    _add_test(str(i), False)

for 循环执行完后,TestCase1 里会有三个测试方法,这些方法可以同时被 noseunittest 支持。

11

这是我最喜欢的处理“相关测试家族”的方法。我喜欢用明确的子类来表示共同的特性。

class MyTestF1( unittest.TestCase ):
    theFunction= staticmethod( f1 )
    def setUp(self):
        self.matrix1 = numpy.ones((5,10))
        self.matrix2 = numpy.identity(5)
    def testOutputShape( self ):
        """Output of functions be of a certain shape"""
        output = self.theFunction(self.matrix1, self.matrix2)
        fail_message = "%s produces output of the wrong shape" % (self.theFunction.__name__,)
        self.assertEqual(self.matrix1.shape, output.shape, fail_message)

class TestF2( MyTestF1 ):
    """Includes ALL of TestF1 tests, plus a new test."""
    theFunction= staticmethod( f2 )
    def testUniqueFeature( self ):
         # blah blah blah
         pass

class TestF3( MyTestF1 ):
    """Includes ALL of TestF1 tests with no additional code."""
    theFunction= staticmethod( f3 )

你只需要添加一个函数,创建一个 MyTestF1 的子类。每个 MyTestF1 的子类都包含了 MyTestF1 中的所有测试,而且没有任何重复的代码。

独特的特性可以通过很明显的方式来处理。你只需在子类中添加新的方法。

这个方法和 unittest.main() 完全兼容。

5

你可以使用元类来动态地插入测试。这对我来说效果很好:

import unittest

class UnderTest(object):

    def f1(self, i):
        return i + 1

    def f2(self, i):
        return i + 2

class TestMeta(type):

    def __new__(cls, name, bases, attrs):
        funcs = [t for t in dir(UnderTest) if t[0] == 'f']

        def doTest(t):
            def f(slf):
                ut=UnderTest()
                getattr(ut, t)(3)
            return f

        for f in funcs:
            attrs['test_gen_' + f] = doTest(f)
        return type.__new__(cls, name, bases, attrs)

class T(unittest.TestCase):

    __metaclass__ = TestMeta

    def testOne(self):
        self.assertTrue(True)

if __name__ == '__main__':
    unittest.main()

撰写回答