如何对数据库包装器进行单元测试?

2 投票
1 回答
1560 浏览
提问于 2025-04-18 02:52

我写了一个用于RethinkDB的数据库封装器,使用的是Python,这个封装器引入了模型(有点像Django中的模型和管理器)。我该如何为它编写单元测试呢?实际上,我该如何测试数据库是否已经更新了我给模型的值?

我想到了直接查询数据库,这确实可行,但这意味着我需要为每次测试运行都创建一个数据库连接(还要设置一个数据库)。有没有办法模拟一个数据库或连接,这样我就能让测试正常工作呢?

目前,我在测试的setUp()方法中创建一个连接对象,为测试创建一个数据库,然后在tearDown()方法中删除这些操作。

1 个回答

2

你可以使用 unittest.mock 这个工具来模拟你在实现包装器时使用的低级API,然后用断言来检查包装器对这个API的调用。

我对django模型或rethinkdb了解不多,但可能看起来像这样。

import uuid
import unittest
import unittest.mock as mock

import wrapper # your wrapper


class Person(wrapper.Model):
    name = wrapper.CharField()

class Tests(unittest.TestCase):
    def setUp(self):
        # you can mock the whole low-level API module
        self.mock_r = mock.Mock()
        self.r_patcher = mock.patch.dict('rethinkdb', rethinkdb=self.mock_r):
        self.r_patcher.start()
        wrapper.connect('localhost', 1234, 'db')

    def tearDown(self):
        self.r_patcher.stop()

    def test_create(self):
        """Test successful document creation."""
        id = uuid.uuid4()
        # rethinkdb.table('persons').insert(...) will return this
        self.mock_r.table().insert.return_value = {
            "deleted": 0,
            "errors": 0,
            "generated_keys": [
                id
            ],
            "inserted": 1,
            "replaced": 0,
            "skipped": 0,
            "unchanged": 0
        }

        name = 'Smith'
        person = wrapper.Person(name=name)
        person.save()

        # checking for a call like rethinkdb.table('persons').insert(...)
        self.mock_r.table.assert_called_once_with('persons')
        expected = {'name': name}
        self.mock_r.table().insert.assert_called_once_with(expected)

        # checking the generated id
        self.assertEqual(person.id, id)

    def test_create_error(self):
        """Test error during document creation."""
        error_msg = "boom!"
        self.mock_r.table().insert.return_value = {
            "deleted": 0,
            "errors": 1,
            "first_error": error_msg
            "inserted": 0,
            "replaced": 0,
            "skipped": 0,
            "unchanged": 0
        }

        name = 'Smith'
        person = wrapper.Person(name=name)

        # expecting error
        with self.assertRaises(wrapper.Error) as error:
            person.save()
        # checking the error message
        self.assertEqual(str(error), error_msg)

这段代码有点粗糙,但我希望你能明白我的意思。

编辑: 添加了 return_value 和错误测试。

撰写回答