diff options
author | Jason Zhang <yurenzhang2017@gmail.com> | 2024-02-19 22:36:11 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-19 14:36:11 -0800 |
commit | c2cb31bbe1262213085c425bc853d6587c66cae9 (patch) | |
tree | 6c3f818a795cfe2ec1209a7b6525f875a3310d54 /Lib/enum.py | |
parent | 6cd18c75a41a74cab69ebef0b7def3e48421bdd1 (diff) | |
download | cpython-c2cb31bbe1262213085c425bc853d6587c66cae9.tar.gz cpython-c2cb31bbe1262213085c425bc853d6587c66cae9.zip |
gh-115539: Allow enum.Flag to have None members (GH-115636)
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 57 |
1 files changed, 36 insertions, 21 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 98a8966f5eb..d10b9961598 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -279,9 +279,10 @@ class _proto_member: enum_member._sort_order_ = len(enum_class._member_names_) if Flag is not None and issubclass(enum_class, Flag): - enum_class._flag_mask_ |= value - if _is_single_bit(value): - enum_class._singles_mask_ |= value + if isinstance(value, int): + enum_class._flag_mask_ |= value + if _is_single_bit(value): + enum_class._singles_mask_ |= value enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 # If another member with the same value was already defined, the @@ -309,6 +310,7 @@ class _proto_member: elif ( Flag is not None and issubclass(enum_class, Flag) + and isinstance(value, int) and _is_single_bit(value) ): # no other instances found, record this member in _member_names_ @@ -1558,37 +1560,50 @@ class Flag(Enum, boundary=STRICT): def __bool__(self): return bool(self._value_) + def _get_value(self, flag): + if isinstance(flag, self.__class__): + return flag._value_ + elif self._member_type_ is not object and isinstance(flag, self._member_type_): + return flag + return NotImplemented + def __or__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with |") value = self._value_ - return self.__class__(value | other) + return self.__class__(value | other_value) def __and__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with &") value = self._value_ - return self.__class__(value & other) + return self.__class__(value & other_value) def __xor__(self, other): - if isinstance(other, self.__class__): - other = other._value_ - elif self._member_type_ is not object and isinstance(other, self._member_type_): - other = other - else: + other_value = self._get_value(other) + if other_value is NotImplemented: return NotImplemented + + for flag in self, other: + if self._get_value(flag) is None: + raise TypeError(f"'{flag}' cannot be combined with other flags with ^") value = self._value_ - return self.__class__(value ^ other) + return self.__class__(value ^ other_value) def __invert__(self): + if self._get_value(self) is None: + raise TypeError(f"'{self}' cannot be inverted") + if self._inverted_ is None: if self._boundary_ in (EJECT, KEEP): self._inverted_ = self.__class__(~self._value_) |