创建自定义 JSONEncoder

8 投票
2 回答
5716 浏览
提问于 2025-04-16 19:55

我正在使用Python 2.7,想要创建一个自定义的FloatEncoder类,这个类是JSONEncoder的一个子类。我参考了很多例子,比如这个,但是都没有成功。以下是我的FloatEncoder类:

class FloatEncoder(JSONEncoder):
    def _iterencode(self, obj, markers=None):
         if isinstance(obj, float):
            return (str(obj) for obj in [obj])
        return super(FloatEncoder, self)._iterencode(obj, markers)

这是我调用json.dumps的地方:

with patch("utils.fileio.FloatEncoder") as float_patch:
        for val,res in ((.00123456,'0.0012'),(.00009,'0.0001'),(0.99999,'1.0000'),({'hello':1.00001,'world':[True,1.00009]},'{"world": [true, 1.0001], "hello": 1.0000}')): 
            untrusted = dumps(val, cls=FloatEncoder)
            self.assertTrue(float_patch._iterencode.called)
            self.assertEqual(untrusted, res)

第一个断言失败了,这意味着_my_iterencode没有被执行。在阅读了JSON的文档后,我尝试重写default()方法,但这个方法也没有被调用。

2 个回答

0

不要定义 _iterencode,应该定义 default,就像那页第三个答案里说的那样。

2

你似乎是在尝试将浮点数值四舍五入到小数点后四位,然后生成JSON格式的数据(根据测试示例来看)。

Python 2.7自带的JSONEncoder没有_iterencode这个方法,所以它不会被调用。而且快速查看一下json/encoder.py文件会发现,这个类的写法让它很难改变浮点数的编码方式。也许,分开处理会更好,先把浮点数四舍五入,然后再进行JSON序列化。

补充说明:Alex Martelli在一个相关的回答中提供了一种猴子补丁的解决方案。这个方法的问题在于,你会对json库的行为进行全局修改,这可能会无意中影响到你应用中其他部分的代码,因为那些代码是基于浮点数没有四舍五入的假设来写的。

试试这个:

from collections import Mapping, Sequence
from unittest import TestCase, main
from json import dumps

def round_floats(o):
    if isinstance(o, float):
        return round(o, 4)
    elif isinstance(o, basestring):
        return o
    elif isinstance(o, Sequence):
        return [round_floats(item) for item in o]
    elif isinstance(o, Mapping):
        return dict((key, round_floats(value)) for key, value in o.iteritems())
    else:
        return o

class TestFoo(TestCase):
    def test_it(self):
        for val, res in ((.00123456, '0.0012'),
                         (.00009, '0.0001'),
                         (0.99999, '1.0'),
                         ({'hello': 1.00001, 'world': [True, 1.00009]},
                          '{"world": [true, 1.0001], "hello": 1.0}')):
            untrusted = dumps(round_floats(val))
            self.assertEqual(untrusted, res)

if __name__ == '__main__':
    main()

撰写回答