枚举所有可能的数据类实例(仅限enum和bool字段)

2024-06-01 00:53:30 发布

您现在位置:Python中文网/ 问答频道 /正文

基本上我需要以下功能。我有一个 Python 3 的 dataclass 或 NamedTuple,其中只有 enum 和 bool 字段。例如:

from enum import Enum, auto
from typing import NamedTuple

class MyEnum(Enum):
    """Example enumeration with three members, used as a field type below."""
    # auto() lets the enum machinery pick each member's value.
    v1 = auto()
    v2 = auto()
    v3 = auto()


class MyStateDefinition(NamedTuple):
    """Example state record whose fields are limited to an enum and a bool."""
    a: MyEnum  # one of the MyEnum members
    b: bool

这里有什么好的解决方案可以枚举这样一个数据类的所有可能的不相等实例吗(上面的示例有6个可能的不相等实例)

也许我不应该使用 dataclass,而应该使用其他东西。或者我应该直接使用更底层的工具吗?(原文此处引用的具体名称在页面抓取时丢失,显示为“^{}”。)

我把它想象成某种表生成器,它接受 namedtuple 或 dataclass 作为输入参数,并生成所有可能的值

# Desired usage: build a table from the record class, then walk every
# generated instance and read its fields by name.
table = DataTable(MyStateDefinition)
for item in table:
    # Use items somehow
    print(item.a)
    print(item.b)

为什么我需要它?我有一些由枚举和布尔值组成的状态定义。我相信它可以用位掩码实现,但是一旦需要用新值扩展位掩码,就会变成一场噩梦。毕竟,位掩码似乎不是 Python 风格的做法

目前,我必须使用自己的实现。但也许我正在重新发明轮子

谢谢


Tags: 实例fromimportautotableenumitemnamedtuple
2条回答

也张贴了我自己的实现。不太理想,我不得不使用一些受保护的成员

用法:

from typing import NamedTuple
from datatable import DataTable

class BoolsEndEnums(NamedTuple):
    """Record with one enum field and one bool field, used to demo DataTable."""
    # NOTE(review): E is not defined in this snippet; presumably the Enum
    # declared in the test module further down -- confirm.
    a: E
    b: bool


# Build the table of all BoolsEndEnums instances.
tbl = DataTable(BoolsEndEnums)

# Rows can be fetched by position.
item = tbl[0]

print(item.a) # a is v1
print(item.b) # b is False

有关更多用法示例,请参见 test_datatable.py 中的 _test_cls

datatable.py

import collections
import dataclasses
from collections.abc import Iterable, Sequence
from enum import Enum
from typing import Union, Any, Tuple, Iterator, get_type_hints, NamedTuple


def is_cls_namedtuple(cls):
    """
    Tell whether ``cls`` looks like a namedtuple class.

    A class qualifies when it subclasses ``tuple`` and exposes a ``_fields``
    attribute. Anything that is not a class at all (an instance, ``None``,
    a typing construct) yields ``False`` instead of the ``TypeError`` that
    ``issubclass`` raises for non-class arguments.

    :param cls: object to inspect
    :return: ``True`` for namedtuple-like classes, ``False`` otherwise
    """
    return (
        isinstance(cls, type)
        and issubclass(cls, tuple)
        and hasattr(cls, "_fields")
    )


class DataTable(Iterable):
    """
    Table of every distinct instance of a record class.

    The record class must be a dataclass or a ``NamedTuple`` subclass whose
    fields are annotated with ``Enum`` subclasses or ``bool``. All rows are
    generated once, at construction time. The first field varies slowest and
    the last field fastest; enum fields follow member iteration order and
    bool fields go ``False`` then ``True``.

    Rows are namedtuples. For a ``NamedTuple`` input its own class is used;
    for a dataclass input a namedtuple class named ``<ClassName>_immutable``
    with the same field names is created.
    """

    def __init__(self, data_cls):
        """
        Build the table for ``data_cls``.

        :param data_cls: dataclass or ``NamedTuple`` subclass to enumerate
        :raises ValueError: if ``data_cls`` is not such a class, if a field
            has no type annotation, or if a field type is neither an ``Enum``
            subclass nor ``bool``
        """
        self._table = []    # rows, in generation order
        self._index = {}    # row number -> row
        self._rindex = {}   # canonical key (see _as_index) -> row number
        self._named_tuple_cls = None

        # Reject non-classes up front: dataclasses.is_dataclass() also accepts
        # dataclass *instances*, which have no __name__ to derive a row class
        # name from.
        if not isinstance(data_cls, type):
            raise ValueError(
                "Only dataclasses and NamedTuple subclasses are supported."
            )

        if dataclasses.is_dataclass(data_cls):
            fields = [f.name for f in dataclasses.fields(data_cls)]
            self._named_tuple_cls = collections.namedtuple(
                f"{data_cls.__name__}_immutable",
                fields
            )
        elif is_cls_namedtuple(data_cls):
            self._named_tuple_cls = data_cls
            fields = data_cls._fields
        else:
            raise ValueError(
                "Only dataclasses and NamedTuple subclasses are supported."
            )

        hints = get_type_hints(data_cls)

        # An un-annotated field (e.g. from collections.namedtuple) cannot be
        # enumerated; report it as unsupported rather than leaking KeyError.
        missing = [f for f in fields if f not in hints]
        if missing:
            raise ValueError(f"Fields without a type annotation: {missing}")

        self._build_table([], [(f, hints[f]) for f in fields])

    def index_of(self, instance):
        """
        Returns record index of given instance in table.

        :param instance: dataclass instance, namedtuple instance, dict of
            field values, or sequence of ``(name, value)`` pairs
        :return: row number, or ``None`` when no row matches
        """
        return self._rindex.get(self._as_index(instance))

    def get(self, **kw):
        """
        Returns instance for given arguments set.

        :param kw: one keyword argument per field of the record class
        :return: the matching row
        :raises KeyError: if no row matches the given field values
        """
        return self._table[self._rindex[self._as_index(kw)]]

    def __len__(self):
        """Number of rows in the table."""
        return len(self._table)

    def __getitem__(self, i: Union[int, slice]):
        """Row at position ``i``, or a list of rows for a slice."""
        return self._table[i]

    def __iter__(self) -> Iterator:
        """Iterate over rows in generation order."""
        return iter(self._table)

    def _build_table(self, defined_fields, remained_fields):
        """
        Recursively expand ``remained_fields`` into table rows.

        :param defined_fields: list of ``(name, value)`` pairs chosen so far
        :param remained_fields: list of ``(name, type)`` pairs still to expand
        :raises ValueError: if a field type is neither an ``Enum`` subclass
            nor ``bool``
        """
        if not remained_fields:
            # Every field has a value: materialise the row and register it.
            record = dict(defined_fields)
            instance = self._named_tuple_cls(**record)
            item_id = len(self._table)
            self._index[item_id] = instance
            self._rindex[self._as_index(record)] = item_id
            self._table.append(instance)
            return

        (next_name, next_type), *rest = remained_fields

        for value in self._field_values(next_type):
            self._build_table(defined_fields + [(next_name, value)], rest)

    @staticmethod
    def _field_values(field_type):
        """
        Return the values a field of ``field_type`` can take, in table order.

        :raises ValueError: for any type other than an ``Enum`` subclass or
            ``bool``
        """
        # issubclass() raises TypeError for annotations that are not classes
        # (typing constructs such as Optional[bool]), hence the isinstance
        # guard: those must end up in the ValueError below.
        if isinstance(field_type, type) and issubclass(field_type, Enum):
            return list(field_type)

        if field_type is bool:
            return [False, True]

        raise ValueError(f"Got unexpected dataclass field type: {field_type}")

    @staticmethod
    def _as_index(v: Union[Any, Tuple[str, Any]]):
        """
        Convert a record description into a canonical, hashable key.

        :param v: dict of field values, dataclass instance, namedtuple
            instance, or sequence of ``(name, value)`` pairs
        :return: tuple of ``(name, value)`` pairs sorted by field name
        :raises TypeError: if ``v`` is none of the supported shapes
        """
        if isinstance(v, dict):
            items = v.items()
        elif dataclasses.is_dataclass(v) and not isinstance(v, type):
            items = dataclasses.asdict(v).items()
        elif is_cls_namedtuple(type(v)):
            items = v._asdict().items()
        elif isinstance(v, Sequence):
            items = v
        else:
            raise TypeError(
                f"Cannot build a table index from {type(v).__name__!r}"
            )

        return tuple(sorted(items, key=lambda x: x[0]))

test_datatable.py

import dataclasses
from enum import Enum, auto
from typing import NamedTuple

import pytest

from dataclass_utils import DataTable


class E(Enum):
    """Three-member enum used as a field type by the test record classes."""
    v1 = auto()
    v2 = auto()
    v3 = auto()


@dataclasses.dataclass
class BoolsEndEnums:
    """Dataclass fixture with one enum field and one bool field."""
    a: E
    b: bool


class BoolsEndEnumsNamedTuple(NamedTuple):
    """NamedTuple fixture with the same two fields as the dataclass fixture."""
    a: E
    b: bool


@dataclasses.dataclass
class HugeSetOfValues:
    """Dataclass fixture with an int field; the negative test expects DataTable to reject it."""
    a: int  # not an Enum or bool field
    b: bool


class NotSupportedCls:
    """Plain class (no dataclass decorator, no NamedTuple base) for the negative test."""
    pass


def _test_cls(cls):
    """
    Shared checks for a record class with fields ``a: E`` and ``b: bool``.

    :param cls: record class accepted by DataTable and constructible as
        ``cls(E.<member>, <bool>)``
    """
    table = DataTable(cls)
    row_count = 6

    # index_of() maps an instance of the input class back to its row number.
    assert table.index_of(cls(E.v1, False)) == 0
    assert table.index_of(cls(E.v3, True)) == (row_count - 1)
    assert len(table) == row_count

    # Lookup by field values.
    third = table.get(a=E.v2, b=False)
    assert third.a == E.v2
    assert third.b is False

    # Lookup by position.
    fourth = table[3]
    assert fourth.a == E.v2
    assert fourth.b is True

    # Plain iteration yields the rows in the same order.
    rows = list(table)

    fifth = rows[4]
    assert fifth.a == E.v3
    assert fifth.b is False

    # Test that we can't change result
    with pytest.raises(AttributeError):
        table[0].a = E.v2


def test_dataclass():
    """Run the shared checks against the dataclass fixture."""
    _test_cls(BoolsEndEnums)


def test_namedtuple():
    """Run the shared checks against the NamedTuple fixture."""
    _test_cls(BoolsEndEnumsNamedTuple)


def test_datatable_neg():
    """
    Negative tests: DataTable must reject unsupported classes with ValueError.
    """
    # First a dataclass with a non-enum/non-bool field, then a class that is
    # not a record class at all.
    for rejected_cls in (HugeSetOfValues, NotSupportedCls):
        with pytest.raises(ValueError):
            DataTable(rejected_cls)

您可以使用 enum 来实现这一点,把数据元组作为 enum 成员的值(如果愿意,也可以使用 Enum/NamedTuple 混合体)。_ignore_ 属性用于防止类命名空间中的某些名称被转换为枚举成员

from itertools import product
from enum import Enum

class Data(Enum):
    """Enum whose members are the (MyEnum member, bool) pairs produced below.

    Each member's value is a 2-tuple; the ``a`` and ``b`` properties expose
    its first and second element.
    """
    # Names listed here are scratch variables of the class body; _ignore_
    # keeps them from being turned into enum members.
    _ignore_ = "Data", "myenum_member", "truthiness"

    @property
    def a(self):
        """First element of the member's value tuple."""
        return self.value[0]

    @property
    def b(self):
        """Second element of the member's value tuple."""
        return self.value[1]

    def __repr__(self):
        """Show the member as ``Data(a=..., b=...)``."""
        return f'Data(a={self.a!r}, b={self.b!r})'

    # Inside a class body vars() is the namespace being built; binding it to
    # a name lets the loop add members under computed names.
    Data = vars()
    for myenum_member, truthiness in product(MyEnum, (True, False)):
        # Member name is "<enum member name>_<bool>"; value is the pair itself.
        Data[f'{myenum_member.name}_{truthiness}'] = (myenum_member, truthiness)

您应该能够按照自己的意愿迭代生成的枚举类

枚举的这种用法类似于官方文档 Enum HOWTO 部分中的 “time period” 示例


动态生成此类表

如果您想动态生成此类表,可以像下面这样(滥)用元类来实现。我在 docstring 中给出了这个 DataTable 类的示例用法。(出于某种原因,在 doctest 中使用 typing.get_type_hints 似乎会导致 doctest 模块出错,但如果您在交互式终端中自己尝试这些示例,它们确实可以工作。)与您在回答中对 bool 做特殊处理不同,我决定对 typing.Literal 做特殊处理,因为这似乎是更具可扩展性的选择(而 bool 可以写成 typing.Literal[True, False])。

from __future__ import annotations
from itertools import product
from enum import Enum, EnumMeta

from typing import (
    Iterable,
    Mapping,
    cast,
    Protocol,
    get_type_hints,
    Any,
    get_args,
    get_origin,
    Literal,
    TypeVar,
    Union,
    Optional
)

# D: member type of an iterable data table (used by index_of and get).
D = TypeVar('D')
# T: type of the fallback value passed to get() as default_.
T = TypeVar('T')


class DataTableFactory(EnumMeta):
    """A helper class for making data tables (an implementation detail of `DataTable`).

    A data table is an Enum class with one member per combination of field
    values. Each member's value is the tuple of field values, and every field
    is readable on the member as a property of the same name. Classes created
    by this metaclass additionally offer `index_of` and `get`.
    """

    # Base classes given to every generated data table.
    _CLS_BASES = (Enum,)

    @classmethod
    def __prepare__(  # type: ignore[override]
            metacls,
            cls_name: str,
            fields: Mapping[str, Iterable[Any]]
    ) -> dict[str, Any]:
        """Build the class namespace for a new data table.

        Note that the signature differs from the usual metaclass
        `__prepare__`: the second argument is the field mapping, not the
        tuple of bases.

        :param cls_name: name of the data table class
        :param fields: maps each field name to the values it can take
        :return: namespace holding one property per field, one member per
            combination of values, and a `__repr__` for members
        """

        cls_dict = cast(
            dict[str, Any],
            super().__prepare__(cls_name, metacls._CLS_BASES)
        )

        # One read-only property per field; `i=i` freezes the position so
        # each lambda reads its own slot of the value tuple.
        for i, field in enumerate(fields.keys()):
            cls_dict[field] = property(fget=lambda self, i=i: self.value[i])  # type: ignore[misc]

        # One member per combination; the member name is the str() of each
        # value joined with underscores.
        for p in product(*fields.values()):
            cls_dict['_'.join(map(str, p))] = p

        def __repr__(self: Enum) -> str:
            contents = ', '.join(
                f'{field}={getattr(self, field)!r}'
                for field in fields
            )
            return f'{cls_name}Member({contents})'

        cls_dict['__repr__'] = __repr__
        return cls_dict

    @classmethod
    def make_datatable(
            metacls,
            cls_name: str,
            *,
            fields: Mapping[str, Iterable[Any]],
            doc: Optional[str] = None
    ) -> type[Enum]:
        """Create a new data table

        :param cls_name: name of the class to create
        :param fields: maps each field name to the values it can take
        :param doc: docstring for the new class; when omitted, one listing
            the fields and their allowed values is generated
        :return: the new Enum class
        """

        cls_dict = metacls.__prepare__(cls_name, fields)
        new_cls = metacls.__new__(metacls, cls_name, metacls._CLS_BASES, cls_dict)
        new_cls.__module__ = __name__

        if doc is None:
            # The label column has the same width on every line, so compute
            # it once; default=0 keeps an empty field mapping from raising.
            label_width = max(map(len, fields), default=0) + 3

            all_attrs = '\n'.join(
                f'    {f"{attr_name}: ":<{label_width}}one of {attr_val!r}'
                for attr_name, attr_val in fields.items()
            )

            fields_len = len(fields)

            doc = (
                f'An enum-like data table.\n\n'
                f'All members of this data table have {fields_len} '
                f'read-only attribute{"s" if fields_len > 1 else ""}:\n'
                f'{all_attrs}\n\n'
                f'                                   '
            )

        new_cls.__doc__ = doc
        return cast(type[Enum], new_cls)

    def __repr__(cls) -> str:
        return f"<Data table '{cls.__name__}'>"

    def index_of(cls: Iterable[D], member: D) -> int:
        """Get the index of a member in the list of members.

        :raises ValueError: if `member` is not among the members
        """
        return list(cls).index(member)

    def get(
            cls: Iterable[D],
            /,
            *,
            default_: Optional[T] = None,
            **kwargs: Any
    ) -> Union[D, T, None]:
        """Return instance for given arguments set.
        Return `default_` if no member matches those arguments.

        Keyword arguments are compared with `==` against the member
        attributes of the same name; the first member matching all of them
        is returned.
        """

        it = (
            member for member in cls
            if all((getattr(member, key) == val) for key, val in kwargs.items())
        )

        return next(it, default_)

    def __dir__(cls) -> list[str]:
        # By defining __dir__, we make methods defined in this class
        # discoverable by the interactive help() function in the REPL
        return cast(list[str], super().__dir__()) + ['index_of', 'get']


class TypedStructProto(Protocol):
    """In order to satisfy this interface, a type must have an __annotations__ dict."""
    # Keys are field names; each value is the field's annotation.
    __annotations__: dict[str, Union[Iterable[Any], type[Literal[True]]]]


class DataTableMeta(type):
    """Metaclass for `DataTable`."""
    # Calling a class that uses this metaclass is routed to
    # DataTableFactory.make_datatable, so the call returns whatever that
    # classmethod returns instead of an instance of the class.
    __call__ = DataTableFactory.make_datatable  # type: ignore[assignment]


class DataTable(metaclass=DataTableMeta):
    """A mechanism to create 'data table enumerations' -- not really a class at all!

    Calling ``DataTable(name, fields=..., doc=...)`` is intercepted by the
    metaclass and returns a new Enum class rather than a ``DataTable``
    instance.

    Example usage
    -------------
    >>> Cars = DataTable('Cars', fields={'make': ('Toyota', 'Audi'), 'colour': ('Red', 'Blue')})
    >>> Cars
    <Data table 'Cars'>
    >>> list(Cars)
    [CarsMember(make='Toyota', colour='Red'), CarsMember(make='Toyota', colour='Blue'), CarsMember(make='Audi', colour='Red'), CarsMember(make='Audi', colour='Blue')]
    >>> Cars.get(make='Audi', colour='Red')
    CarsMember(make='Audi', colour='Red')
    >>> Cars.index_of(_)
    2
    """

    @classmethod
    def from_struct(cls, cls_name: str, *, struct: type[TypedStructProto], doc: Optional[str] = None) -> type[Enum]:
        """Make a DataTable from a "typed struct" -- e.g. a dataclass, NamedTuple or TypedDict.

        Every field annotation must be iterable (such as an Enum class) or a
        ``Literal[...]``; a ``Literal`` is replaced by the tuple of its
        arguments.

        :param cls_name: name of the data table to create
        :param struct: class whose type hints define the fields
        :param doc: optional docstring for the new data table
        :return: the new data table (an Enum class)

        Example usage (works the same way with dataclasses and TypedDicts)
        ------------------------------------------------------------------
        >>> from enum import Enum, auto
        >>> from typing import NamedTuple, Literal
        >>> class E(Enum):
        ...     v1 = auto()
        ...     v2 = auto()
        ...     v3 = auto()
        ...
        >>> class BoolsEndEnums(NamedTuple):
        ...     a: E
        ...     b: Literal[True, False]
        ...
        >>> BoolsEndEnumsTable = DataTable.from_struct('BoolsEndEnumsTable', struct=BoolsEndEnums)
        >>> list(BoolsEndEnumsTable)
        [BoolsEndEnumsTableMember(a=<E.v1: 1>, b=True), BoolsEndEnumsTableMember(a=<E.v1: 1>, b=False), BoolsEndEnumsTableMember(a=<E.v2: 2>, b=True), BoolsEndEnumsTableMember(a=<E.v2: 2>, b=False), BoolsEndEnumsTableMember(a=<E.v3: 3>, b=True), BoolsEndEnumsTableMember(a=<E.v3: 3>, b=False)]
        """

        # Build a fresh mapping instead of rewriting the hints dict while
        # iterating over it.
        fields = {
            field_name: get_args(hint) if get_origin(hint) is Literal else hint
            for field_name, hint in get_type_hints(struct).items()
        }

        return cast(type[Enum], cls(cls_name, fields=fields, doc=doc))  # type: ignore[call-arg]

我不得不在类型提示上做一些“有趣”的事情,但 MyPy 对这一切都没有报错

相关问题 更多 >