如何使用mock库修补方法''.join

4 投票
5 回答
2697 浏览
提问于 2025-04-17 17:30

为了给某个函数创建单元测试,我需要对 ''.join(...) 进行修改。

我尝试了很多方法(使用 mock 库),但就是无法成功,尽管我在使用这个库创建单元测试方面有一些经验。

第一个问题是 str 是一个内置类,因此无法被修改。William John Bert 的一篇文章展示了如何处理这个问题(他处理的是 datetime.date)。在这个库的官方文档中,“部分修改”部分也提供了一个可能的解决方案。

第二个问题是 str 并不是直接使用的。实际上,是调用了字面量 ''join 方法。那么,我应该修改哪个路径呢?

以下这些选项都没有成功:

  • patch('__builtin__.str', 'join')
  • patch('string.join')
  • patch('__builtin__.str', FakeStr)(其中 FakeStrstr 的一个子类)

任何帮助都会非常感谢。

5 个回答

0

其实没有办法直接对字符串字面量进行这种操作,因为它们总是使用内置的 str 类,而这个类是无法通过这种方式进行修改的。

当然,你可以写一个函数 join(seq, sep=''),用它来替代 ''.join(),然后对这个函数进行修改;或者你可以创建一个 str 的子类,叫做 Separator,专门用来构造那些需要进行 join 操作的字符串(比如 Separator('').join(....))。这些变通方法虽然有点麻烦,但除此之外你就没法修改这个方法了。

3

如果你觉得自己运气特别好,可以去查看并修改代码中的字符串常量:

def patch_strings(fun, cls):
    new_consts = tuple(
                  cls(c) if type(c) is str else c
                  for c in fun.func_code.co_consts)

    code = type(fun.func_code)

    fun.func_code = code(
           fun.func_code.co_argcount,
           fun.func_code.co_nlocals, 
           fun.func_code.co_stacksize,
           fun.func_code.co_flags,
           fun.func_code.co_code,
           new_consts,
           fun.func_code.co_names,
           fun.func_code.co_varnames,
           fun.func_code.co_filename,
           fun.func_code.co_name,
           fun.func_code.co_firstlineno,
           fun.func_code.co_lnotab,
           fun.func_code.co_freevars,
           fun.func_code.co_cellvars)

def a():
    return ''.join(['a', 'b'])

class mystr(str):
    def join(self, s):
        print 'join called!'
        return super(mystr, self).join(s)

patch_strings(a, mystr)
print a()      # prints "join called!\nab"

这是Python3版本:

def patch_strings(fun, cls):
    new_consts = tuple(
                   cls(c) if type(c) is str else c
                   for c in fun.__code__.co_consts)

    code = type(fun.__code__)

    fun.__code__ = code(
           fun.__code__.co_argcount,
           fun.__code__.co_kwonlyargcount,
           fun.__code__.co_nlocals, 
           fun.__code__.co_stacksize,
           fun.__code__.co_flags,
           fun.__code__.co_code,
           new_consts,
           fun.__code__.co_names,
           fun.__code__.co_varnames,
           fun.__code__.co_filename,
           fun.__code__.co_name,
           fun.__code__.co_firstlineno,
           fun.__code__.co_lnotab,
           fun.__code__.co_freevars,
           fun.__code__.co_cellvars)
4

你不能这样做,因为内置类是无法设置属性的:

>>> str.join = lambda x: None
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can't set attributes of built-in/extension type 'str'

而且你也不能修改 str,因为 ''.join 使用的是一个字面量,所以解释器总是会创建一个 str,无论你怎么尝试去替换 str__builtin__ 中。

如果你查看生成的字节码,就能明白这一点:

>>> import dis
>>> def test():
...     ''.join([1,2,3])
... 
>>> dis.dis(test)
  2           0 LOAD_CONST               1 ('')
              3 LOAD_ATTR                0 (join)
              6 LOAD_CONST               2 (1)
              9 LOAD_CONST               3 (2)
             12 LOAD_CONST               4 (3)
             15 BUILD_LIST               3
             18 CALL_FUNCTION            1
             21 POP_TOP             
             22 LOAD_CONST               0 (None)
             25 RETURN_VALUE

字节码是在编译时生成的,正如你所看到的,第一个 LOAD_CONST 加载的是 '',这始终是一个 str,无论你在运行时如何改变 str 的值。

可以做的是使用一个包装函数,这样可以被模拟,或者避免使用字面量。例如,使用 str() 代替 '',这样你就可以用一个子类来模拟 str 类,并按照你想要的方式实现 join 方法(尽管这可能会影响很多代码,并且根据你使用的模块可能不太可行)。

撰写回答