为类装饰器添加类型注解

2 投票
1 回答
109 浏览
提问于 2025-04-14 16:58

我正在写一些micropython代码(micropython没有数据类),目的是创建一个可以被序列化和反序列化的通用类。我写了一个装饰器函数,用来给这个类添加两个功能:

import pickle
from typing import Any
from abc import ABCMeta, abstractmethod

class Serializable(ABCMeta):
    @abstractmethod
    def serialize(self) -> str:
        pass

    @staticmethod
    @abstractmethod
    def deserialize(data: str) -> 'Serializable':
        pass


def serializable(cls):
    if not hasattr(cls, 'fields'):
        raise ValueError('fields is required')
    def __init__(self, **kwargs: Any) -> None:
        for fname, _ in cls.fields:
            if fname not in kwargs:
                raise ValueError(f'{fname} is required')
            setattr(self, fname, kwargs.get(fname))
    cls.__init__ = __init__

    def serialize(self: Serializable) -> str:
        return f"{cls.__name__} {' '.join(str(getattr(self, fname)) for fname, _ in cls.fields)}"
    cls.serialize = serialize

    def deserialize(data: str) -> Serializable:
        parts = data.split(' ')
        if parts[0] != cls.__name__:
            raise ValueError(f'invalid data not a {cls.__name__}')
        if len(parts) - 1 != len(cls.fields):
            raise ValueError('invalid data not enough fields')
        return cls(**{fname: ftype(parts[i+1]) for i, (fname, ftype) in enumerate(cls.fields)})
    cls.deserialize = deserialize

    return cls

因为我们在项目中使用严格的类型检查工具mypy,所以它会抱怨一些函数(比如这个装饰器)没有经过类型检查。此外,当我这样使用这个类时,它也会发出警告:

@serializable
class Test:
    fields = [
        ('a', int),
        ('b', str)
    ]

x = Test(a=1, b='hello')
print(x.a)

我该如何让mypy正确地进行类型检查呢?

1 个回答

1

正如评论中提到的,dataclass_transform 是你需要的东西,不过你需要用 __annotations__ 来代替你例子中的自定义 fields。好消息是,这样写其实更简单:

from typing import Any, dataclass_transform
from abc import ABCMeta, abstractmethod

class Serializable(ABCMeta):
    @abstractmethod
    def serialize(self) -> str:
        pass

    @staticmethod
    @abstractmethod
    def deserialize(data: str) -> 'Serializable':
        pass

@dataclass_transform()
def serializable(cls):
    def __init__(self, **kwargs: Any) -> None:
        for fname, _ in cls.__annotations__.items():
            if fname not in kwargs:
                raise ValueError(f'{fname} is required')
            setattr(self, fname, kwargs[fname])
    cls.__init__ = __init__

    def serialize(self: Serializable) -> str:
        return f"{cls.__name__} {' '.join(str(getattr(self, fname)) for fname, _ in cls.fields)}"
    cls.serialize = serialize

    def deserialize(data: str) -> Serializable:
        parts = data.split(' ')
        if parts[0] != cls.__name__:
            raise ValueError(f'invalid data not a {cls.__name__}')
        if len(parts) - 1 != len(cls.fields):
            raise ValueError('invalid data not enough fields')
        return cls(**{fname: ftype(parts[i+1]) for i, (fname, ftype) in enumerate(cls.fields)})
    cls.deserialize = deserialize

    return cls

@serializable
class Test:
    a: int
    b: str

x = Test(a=1, b='hello')
print(x.a)

撰写回答