为类装饰器添加类型注解
我正在写一些micropython代码(micropython没有数据类),目的是创建一个可以被序列化和反序列化的通用类。我写了一个装饰器函数,用来给这个类添加两个功能:
import pickle
from typing import Any
from abc import ABCMeta, abstractmethod
class Serializable(ABCMeta):
@abstractmethod
def serialize(self) -> str:
pass
@staticmethod
@abstractmethod
def deserialize(data: str) -> 'Serializable':
pass
def serializable(cls):
if not hasattr(cls, 'fields'):
raise ValueError('fields is required')
def __init__(self, **kwargs: Any) -> None:
for fname, _ in cls.fields:
if fname not in kwargs:
raise ValueError(f'{fname} is required')
setattr(self, fname, kwargs.get(fname))
cls.__init__ = __init__
def serialize(self: Serializable) -> str:
return f"{cls.__name__} {' '.join(str(getattr(self, fname)) for fname, _ in cls.fields)}"
cls.serialize = serialize
def deserialize(data: str) -> Serializable:
parts = data.split(' ')
if parts[0] != cls.__name__:
raise ValueError(f'invalid data not a {cls.__name__}')
if len(parts) - 1 != len(cls.fields):
raise ValueError('invalid data not enough fields')
return cls(**{fname: ftype(parts[i+1]) for i, (fname, ftype) in enumerate(cls.fields)})
cls.deserialize = deserialize
return cls
因为我们在项目中使用严格的类型检查工具mypy,所以它会抱怨一些函数(比如这个装饰器)没有经过类型检查。此外,当我这样使用这个类时,它也会发出警告:
@serializable
class Test:
fields = [
('a', int),
('b', str)
]
x = Test(a=1, b='hello')
print(x.a)
我该如何让mypy正确地进行类型检查呢?
1 个回答
1
正如评论中提到的,dataclass_transform
是你需要的东西,不过你需要用 __annotations__
来代替你例子中的自定义 fields
。好消息是,这样写其实更简单:
from typing import Any, dataclass_transform
from abc import ABCMeta, abstractmethod
class Serializable(ABCMeta):
@abstractmethod
def serialize(self) -> str:
pass
@staticmethod
@abstractmethod
def deserialize(data: str) -> 'Serializable':
pass
@dataclass_transform()
def serializable(cls):
def __init__(self, **kwargs: Any) -> None:
for fname, _ in cls.__annotations__.items():
if fname not in kwargs:
raise ValueError(f'{fname} is required')
setattr(self, fname, kwargs[fname])
cls.__init__ = __init__
def serialize(self: Serializable) -> str:
return f"{cls.__name__} {' '.join(str(getattr(self, fname)) for fname, _ in cls.fields)}"
cls.serialize = serialize
def deserialize(data: str) -> Serializable:
parts = data.split(' ')
if parts[0] != cls.__name__:
raise ValueError(f'invalid data not a {cls.__name__}')
if len(parts) - 1 != len(cls.fields):
raise ValueError('invalid data not enough fields')
return cls(**{fname: ftype(parts[i+1]) for i, (fname, ftype) in enumerate(cls.fields)})
cls.deserialize = deserialize
return cls
@serializable
class Test:
a: int
b: str
x = Test(a=1, b='hello')
print(x.a)