aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2025-03-04 11:44:19 -0800
committerGitHub <noreply@github.com>2025-03-04 11:44:19 -0800
commitdc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada (patch)
tree045fed4b7965d56ea45c009dad6dddb42d7be8b0
parente091520fdbcfe406e5fdcf66b7864b2b34a6726b (diff)
downloadcpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.tar.gz
cpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.zip
gh-105499: Merge typing.Union and types.UnionType (#105511)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com> Co-authored-by: Ken Jin <kenjin@python.org> Co-authored-by: Carl Meyer <carl@oddbird.net>
-rw-r--r--Doc/deprecations/pending-removal-in-future.rst5
-rw-r--r--Doc/library/functools.rst6
-rw-r--r--Doc/library/stdtypes.rst21
-rw-r--r--Doc/library/types.rst4
-rw-r--r--Doc/library/typing.rst10
-rw-r--r--Doc/whatsnew/3.10.rst4
-rw-r--r--Doc/whatsnew/3.11.rst4
-rw-r--r--Include/internal/pycore_unionobject.h1
-rw-r--r--Lib/functools.py17
-rw-r--r--Lib/test/test_dataclasses/__init__.py2
-rw-r--r--Lib/test/test_functools.py2
-rw-r--r--Lib/test/test_inspect/test_inspect.py4
-rw-r--r--Lib/test/test_pydoc/test_pydoc.py16
-rw-r--r--Lib/test/test_types.py50
-rw-r--r--Lib/test/test_typing.py96
-rw-r--r--Lib/typing.py144
-rw-r--r--Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst3
-rw-r--r--Modules/_typingmodule.c8
-rw-r--r--Objects/typevarobject.c14
-rw-r--r--Objects/unionobject.c438
20 files changed, 542 insertions, 307 deletions
diff --git a/Doc/deprecations/pending-removal-in-future.rst b/Doc/deprecations/pending-removal-in-future.rst
index 42dce518717..df8d18782ce 100644
--- a/Doc/deprecations/pending-removal-in-future.rst
+++ b/Doc/deprecations/pending-removal-in-future.rst
@@ -127,6 +127,11 @@ although there is currently no date scheduled for their removal.
* :class:`typing.Text` (:gh:`92332`).
+* The internal class ``typing._UnionGenericAlias`` is no longer used to implement
+ :class:`typing.Union`. To preserve compatibility with users using this private
+ class, a compatibility shim will be provided until at least Python 3.17. (Contributed by
+ Jelle Zijlstra in :gh:`105499`.)
+
* :class:`unittest.IsolatedAsyncioTestCase`: it is deprecated to return a value
that is not ``None`` from a test case.
diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst
index 8ad5f48c9e5..3a933dff057 100644
--- a/Doc/library/functools.rst
+++ b/Doc/library/functools.rst
@@ -518,7 +518,7 @@ The :mod:`functools` module defines the following functions:
... for i, elem in enumerate(arg):
... print(i, elem)
- :data:`types.UnionType` and :data:`typing.Union` can also be used::
+ :class:`typing.Union` can also be used::
>>> @fun.register
... def _(arg: int | float, verbose=False):
@@ -654,8 +654,8 @@ The :mod:`functools` module defines the following functions:
The :func:`register` attribute now supports using type annotations.
.. versionchanged:: 3.11
- The :func:`register` attribute now supports :data:`types.UnionType`
- and :data:`typing.Union` as type annotations.
+ The :func:`register` attribute now supports
+ :class:`typing.Union` as a type annotation.
.. class:: singledispatchmethod(func)
diff --git a/Doc/library/stdtypes.rst b/Doc/library/stdtypes.rst
index 0564981b52e..a6260ecd77f 100644
--- a/Doc/library/stdtypes.rst
+++ b/Doc/library/stdtypes.rst
@@ -5364,7 +5364,7 @@ Union Type
A union object holds the value of the ``|`` (bitwise or) operation on
multiple :ref:`type objects <bltin-type-objects>`. These types are intended
primarily for :term:`type annotations <annotation>`. The union type expression
-enables cleaner type hinting syntax compared to :data:`typing.Union`.
+enables cleaner type hinting syntax compared to subscripting :class:`typing.Union`.
.. describe:: X | Y | ...
@@ -5400,9 +5400,10 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`.
int | str == str | int
- * It is compatible with :data:`typing.Union`::
+ * It creates instances of :class:`typing.Union`::
int | str == typing.Union[int, str]
+ type(int | str) is typing.Union
* Optional types can be spelled as a union with ``None``::
@@ -5428,16 +5429,15 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`.
TypeError: isinstance() argument 2 cannot be a parameterized generic
The user-exposed type for the union object can be accessed from
-:data:`types.UnionType` and used for :func:`isinstance` checks. An object cannot be
-instantiated from the type::
+:class:`typing.Union` and used for :func:`isinstance` checks::
- >>> import types
- >>> isinstance(int | str, types.UnionType)
+ >>> import typing
+ >>> isinstance(int | str, typing.Union)
True
- >>> types.UnionType()
+ >>> typing.Union()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
- TypeError: cannot create 'types.UnionType' instances
+ TypeError: cannot create 'typing.Union' instances
.. note::
The :meth:`!__or__` method for type objects was added to support the syntax
@@ -5464,6 +5464,11 @@ instantiated from the type::
.. versionadded:: 3.10
+.. versionchanged:: 3.14
+
+ Union objects are now instances of :class:`typing.Union`. Previously, they were instances
+ of :class:`types.UnionType`, which remains an alias for :class:`typing.Union`.
+
.. _typesother:
diff --git a/Doc/library/types.rst b/Doc/library/types.rst
index 439e119461f..2bedd7fdd3c 100644
--- a/Doc/library/types.rst
+++ b/Doc/library/types.rst
@@ -314,6 +314,10 @@ Standard names are defined for the following types:
.. versionadded:: 3.10
+ .. versionchanged:: 3.14
+
+ This is now an alias for :class:`typing.Union`.
+
.. class:: TracebackType(tb_next, tb_frame, tb_lasti, tb_lineno)
The type of traceback objects such as found in ``sys.exception().__traceback__``.
diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst
index 13108abdeaf..6ea4272eecf 100644
--- a/Doc/library/typing.rst
+++ b/Doc/library/typing.rst
@@ -1086,7 +1086,7 @@ Special forms
These can be used as types in annotations. They all support subscription using
``[]``, but each has a unique syntax.
-.. data:: Union
+.. class:: Union
Union type; ``Union[X, Y]`` is equivalent to ``X | Y`` and means either X or Y.
@@ -1121,6 +1121,14 @@ These can be used as types in annotations. They all support subscription using
Unions can now be written as ``X | Y``. See
:ref:`union type expressions<types-union>`.
+ .. versionchanged:: 3.14
+ :class:`types.UnionType` is now an alias for :class:`Union`, and both
+ ``Union[int, str]`` and ``int | str`` create instances of the same class.
+ To check whether an object is a ``Union`` at runtime, use
+ ``isinstance(obj, Union)``. For compatibility with earlier versions of
+ Python, use
+ ``get_origin(obj) is typing.Union or get_origin(obj) is types.UnionType``.
+
.. data:: Optional
``Optional[X]`` is equivalent to ``X | None`` (or ``Union[X, None]``).
diff --git a/Doc/whatsnew/3.10.rst b/Doc/whatsnew/3.10.rst
index e4699fbf8ed..3c815721a92 100644
--- a/Doc/whatsnew/3.10.rst
+++ b/Doc/whatsnew/3.10.rst
@@ -722,10 +722,10 @@ PEP 604: New Type Union Operator
A new type union operator was introduced which enables the syntax ``X | Y``.
This provides a cleaner way of expressing 'either type X or type Y' instead of
-using :data:`typing.Union`, especially in type hints.
+using :class:`typing.Union`, especially in type hints.
In previous versions of Python, to apply a type hint for functions accepting
-arguments of multiple types, :data:`typing.Union` was used::
+arguments of multiple types, :class:`typing.Union` was used::
def square(number: Union[int, float]) -> Union[int, float]:
return number ** 2
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index ed41ecd50b0..2dd205dd2b8 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -740,8 +740,8 @@ fractions
functools
---------
-* :func:`functools.singledispatch` now supports :data:`types.UnionType`
- and :data:`typing.Union` as annotations to the dispatch argument.::
+* :func:`functools.singledispatch` now supports :class:`types.UnionType`
+ and :class:`typing.Union` as annotations to the dispatch argument.::
>>> from functools import singledispatch
>>> @singledispatch
diff --git a/Include/internal/pycore_unionobject.h b/Include/internal/pycore_unionobject.h
index 6ece7134cde..4bd36f6504d 100644
--- a/Include/internal/pycore_unionobject.h
+++ b/Include/internal/pycore_unionobject.h
@@ -18,6 +18,7 @@ PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject *, PyObject *);
extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
extern PyObject *_Py_make_parameters(PyObject *);
extern PyObject *_Py_union_args(PyObject *self);
+extern PyObject *_Py_union_from_tuple(PyObject *args);
#ifdef __cplusplus
}
diff --git a/Lib/functools.py b/Lib/functools.py
index 70c59b109d9..5e2579f6d8e 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -926,16 +926,11 @@ def singledispatch(func):
dispatch_cache[cls] = impl
return impl
- def _is_union_type(cls):
- from typing import get_origin, Union
- return get_origin(cls) in {Union, UnionType}
-
def _is_valid_dispatch_type(cls):
if isinstance(cls, type):
return True
- from typing import get_args
- return (_is_union_type(cls) and
- all(isinstance(arg, type) for arg in get_args(cls)))
+ return (isinstance(cls, UnionType) and
+ all(isinstance(arg, type) for arg in cls.__args__))
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@@ -967,7 +962,7 @@ def singledispatch(func):
from annotationlib import Format, ForwardRef
argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
if not _is_valid_dispatch_type(cls):
- if _is_union_type(cls):
+ if isinstance(cls, UnionType):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} not all arguments are classes."
@@ -983,10 +978,8 @@ def singledispatch(func):
f"{cls!r} is not a class."
)
- if _is_union_type(cls):
- from typing import get_args
-
- for arg in get_args(cls):
+ if isinstance(cls, UnionType):
+ for arg in cls.__args__:
registry[arg] = func
else:
registry[cls] = func
diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py
index 8209374c36b..869a043211b 100644
--- a/Lib/test/test_dataclasses/__init__.py
+++ b/Lib/test/test_dataclasses/__init__.py
@@ -2314,7 +2314,7 @@ class TestDocString(unittest.TestCase):
class C:
x: Union[int, type(None)] = None
- self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")
+ self.assertDocStrEqual(C.__doc__, "C(x:int|None=None)")
def test_docstring_list_field(self):
@dataclass
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index d7404a81c23..ef85664cb78 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -3083,7 +3083,7 @@ class TestSingleDispatch(unittest.TestCase):
"Invalid annotation for 'arg'."
)
self.assertEndsWith(str(exc.exception),
- 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
+ 'int | typing.Iterable[str] not all arguments are classes.'
)
def test_invalid_positional_argument(self):
diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py
index 03f2bacb3a4..73cf5ac64ee 100644
--- a/Lib/test/test_inspect/test_inspect.py
+++ b/Lib/test/test_inspect/test_inspect.py
@@ -1750,8 +1750,8 @@ class TestClassesAndFunctions(unittest.TestCase):
class TestFormatAnnotation(unittest.TestCase):
def test_typing_replacement(self):
from test.typinganndata.ann_module9 import ann, ann1
- self.assertEqual(inspect.formatannotation(ann), 'Union[List[str], int]')
- self.assertEqual(inspect.formatannotation(ann1), 'Union[List[testModule.typing.A], int]')
+ self.assertEqual(inspect.formatannotation(ann), 'List[str] | int')
+ self.assertEqual(inspect.formatannotation(ann1), 'List[testModule.typing.A] | int')
def test_forwardref(self):
fwdref = ForwardRef('fwdref')
diff --git a/Lib/test/test_pydoc/test_pydoc.py b/Lib/test/test_pydoc/test_pydoc.py
index 0abd36c5e07..2b1a4484c68 100644
--- a/Lib/test/test_pydoc/test_pydoc.py
+++ b/Lib/test/test_pydoc/test_pydoc.py
@@ -133,7 +133,7 @@ DATA
c_alias = test.test_pydoc.pydoc_mod.C[int]
list_alias1 = typing.List[int]
list_alias2 = list[int]
- type_union1 = typing.Union[int, str]
+ type_union1 = int | str
type_union2 = int | str
VERSION
@@ -223,7 +223,7 @@ Data
c_alias = test.test_pydoc.pydoc_mod.C[int]
list_alias1 = typing.List[int]
list_alias2 = list[int]
- type_union1 = typing.Union[int, str]
+ type_union1 = int | str
type_union2 = int | str
Author
@@ -1447,17 +1447,17 @@ class TestDescriptions(unittest.TestCase):
self.assertIn(list.__doc__.strip().splitlines()[0], doc)
def test_union_type(self):
- self.assertEqual(pydoc.describe(typing.Union[int, str]), '_UnionGenericAlias')
+ self.assertEqual(pydoc.describe(typing.Union[int, str]), 'Union')
doc = pydoc.render_doc(typing.Union[int, str], renderer=pydoc.plaintext)
- self.assertIn('_UnionGenericAlias in module typing', doc)
- self.assertIn('Union = typing.Union', doc)
+ self.assertIn('Union in module typing', doc)
+ self.assertIn('class Union(builtins.object)', doc)
if typing.Union.__doc__:
self.assertIn(typing.Union.__doc__.strip().splitlines()[0], doc)
- self.assertEqual(pydoc.describe(int | str), 'UnionType')
+ self.assertEqual(pydoc.describe(int | str), 'Union')
doc = pydoc.render_doc(int | str, renderer=pydoc.plaintext)
- self.assertIn('UnionType in module types object', doc)
- self.assertIn('\nclass UnionType(builtins.object)', doc)
+ self.assertIn('Union in module typing', doc)
+ self.assertIn('class Union(builtins.object)', doc)
if not MISSING_C_DOCSTRINGS:
self.assertIn(types.UnionType.__doc__.strip().splitlines()[0], doc)
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index d1161719d98..5a65b5dacaf 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -709,10 +709,6 @@ class UnionTests(unittest.TestCase):
y = int | bool
with self.assertRaises(TypeError):
x < y
- # Check that we don't crash if typing.Union does not have a tuple in __args__
- y = typing.Union[str, int]
- y.__args__ = [str, int]
- self.assertEqual(x, y)
def test_hash(self):
self.assertEqual(hash(int | str), hash(str | int))
@@ -727,17 +723,40 @@ class UnionTests(unittest.TestCase):
self.assertEqual((A | B).__args__, (A, B))
union1 = A | B
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union1)
union2 = int | B
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union2)
union3 = A | int
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union3)
+ def test_unhashable_becomes_hashable(self):
+ is_hashable = False
+ class UnhashableMeta(type):
+ def __hash__(self):
+ if is_hashable:
+ return 1
+ else:
+ raise TypeError("not hashable")
+
+ class A(metaclass=UnhashableMeta): ...
+ class B(metaclass=UnhashableMeta): ...
+
+ union = A | B
+ self.assertEqual(union.__args__, (A, B))
+
+ with self.assertRaisesRegex(TypeError, "not hashable"):
+ hash(union)
+
+ is_hashable = True
+
+ with self.assertRaisesRegex(TypeError, "union contains 2 unhashable elements"):
+ hash(union)
+
def test_instancecheck_and_subclasscheck(self):
for x in (int | str, typing.Union[int, str]):
with self.subTest(x=x):
@@ -921,7 +940,7 @@ class UnionTests(unittest.TestCase):
self.assertEqual(typing.get_args(typing.get_type_hints(forward_after)['x']),
(int, Forward))
self.assertEqual(typing.get_args(typing.get_type_hints(forward_before)['x']),
- (int, Forward))
+ (Forward, int))
def test_or_type_operator_with_Protocol(self):
class Proto(typing.Protocol):
@@ -1015,9 +1034,14 @@ class UnionTests(unittest.TestCase):
return 1 / 0
bt = BadType('bt', (), {})
+ bt2 = BadType('bt2', (), {})
# Comparison should fail and errors should propagate out for bad types.
+ union1 = int | bt
+ union2 = int | bt2
+ with self.assertRaises(ZeroDivisionError):
+ union1 == union2
with self.assertRaises(ZeroDivisionError):
- list[int] | list[bt]
+ bt | bt2
union_ga = (list[str] | int, collections.abc.Callable[..., str] | int,
d | int)
@@ -1060,6 +1084,14 @@ class UnionTests(unittest.TestCase):
self.assertLessEqual(sys.gettotalrefcount() - before, leeway,
msg='Check for union reference leak.')
+ def test_instantiation(self):
+ with self.assertRaises(TypeError):
+ types.UnionType()
+ self.assertIs(int, types.UnionType[int])
+ self.assertIs(int, types.UnionType[int, int])
+ self.assertEqual(int | str, types.UnionType[int, str])
+ self.assertEqual(int | typing.ForwardRef("str"), types.UnionType[int, "str"])
+
class MappingProxyTests(unittest.TestCase):
mappingproxy = types.MappingProxyType
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 591fb860eee..e88c811bfca 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -502,7 +502,7 @@ class TypeVarTests(BaseTestCase):
def test_bound_errors(self):
with self.assertRaises(TypeError):
- TypeVar('X', bound=Union)
+ TypeVar('X', bound=Optional)
with self.assertRaises(TypeError):
TypeVar('X', str, float, bound=Employee)
with self.assertRaisesRegex(TypeError,
@@ -542,7 +542,7 @@ class TypeVarTests(BaseTestCase):
def test_bad_var_substitution(self):
T = TypeVar('T')
bad_args = (
- (), (int, str), Union,
+ (), (int, str), Optional,
Generic, Generic[T], Protocol, Protocol[T],
Final, Final[int], ClassVar, ClassVar[int],
)
@@ -2044,10 +2044,6 @@ class UnionTests(BaseTestCase):
def test_union_issubclass_type_error(self):
with self.assertRaises(TypeError):
- issubclass(int, Union)
- with self.assertRaises(TypeError):
- issubclass(Union, int)
- with self.assertRaises(TypeError):
issubclass(Union[int, str], int)
with self.assertRaises(TypeError):
issubclass(int, Union[str, list[int]])
@@ -2121,41 +2117,40 @@ class UnionTests(BaseTestCase):
self.assertEqual(Union[A, B].__args__, (A, B))
union1 = Union[A, B]
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union1)
union2 = Union[int, B]
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union2)
union3 = Union[A, int]
- with self.assertRaises(TypeError):
+ with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"):
hash(union3)
def test_repr(self):
- self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int]
- self.assertEqual(repr(u), 'typing.Union[%s.Employee, int]' % __name__)
+ self.assertEqual(repr(u), f'{__name__}.Employee | int')
u = Union[int, Employee]
- self.assertEqual(repr(u), 'typing.Union[int, %s.Employee]' % __name__)
+ self.assertEqual(repr(u), f'int | {__name__}.Employee')
T = TypeVar('T')
u = Union[T, int][int]
self.assertEqual(repr(u), repr(int))
u = Union[List[int], int]
- self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]')
+ self.assertEqual(repr(u), 'typing.List[int] | int')
u = Union[list[int], dict[str, float]]
- self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]')
+ self.assertEqual(repr(u), 'list[int] | dict[str, float]')
u = Union[int | float]
- self.assertEqual(repr(u), 'typing.Union[int, float]')
+ self.assertEqual(repr(u), 'int | float')
u = Union[None, str]
- self.assertEqual(repr(u), 'typing.Optional[str]')
+ self.assertEqual(repr(u), 'None | str')
u = Union[str, None]
- self.assertEqual(repr(u), 'typing.Optional[str]')
+ self.assertEqual(repr(u), 'str | None')
u = Union[None, str, int]
- self.assertEqual(repr(u), 'typing.Union[NoneType, str, int]')
+ self.assertEqual(repr(u), 'None | str | int')
u = Optional[str]
- self.assertEqual(repr(u), 'typing.Optional[str]')
+ self.assertEqual(repr(u), 'str | None')
def test_dir(self):
dir_items = set(dir(Union[str, int]))
@@ -2167,14 +2162,11 @@ class UnionTests(BaseTestCase):
def test_cannot_subclass(self):
with self.assertRaisesRegex(TypeError,
- r'Cannot subclass typing\.Union'):
+ r"type 'typing\.Union' is not an acceptable base type"):
class C(Union):
pass
- with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE):
- class D(type(Union)):
- pass
with self.assertRaisesRegex(TypeError,
- r'Cannot subclass typing\.Union\[int, str\]'):
+ r'Cannot subclass int \| str'):
class E(Union[int, str]):
pass
@@ -2220,7 +2212,7 @@ class UnionTests(BaseTestCase):
def test_function_repr_union(self):
def fun() -> int: ...
- self.assertEqual(repr(Union[fun, int]), 'typing.Union[fun, int]')
+ self.assertEqual(repr(Union[fun, int]), f'{__name__}.{fun.__qualname__} | int')
def test_union_str_pattern(self):
# Shouldn't crash; see http://bugs.python.org/issue25390
@@ -4895,11 +4887,11 @@ class GenericTests(BaseTestCase):
def test_extended_generic_rules_repr(self):
T = TypeVar('T')
self.assertEqual(repr(Union[Tuple, Callable]).replace('typing.', ''),
- 'Union[Tuple, Callable]')
+ 'Tuple | Callable')
self.assertEqual(repr(Union[Tuple, Tuple[int]]).replace('typing.', ''),
- 'Union[Tuple, Tuple[int]]')
+ 'Tuple | Tuple[int]')
self.assertEqual(repr(Callable[..., Optional[T]][int]).replace('typing.', ''),
- 'Callable[..., Optional[int]]')
+ 'Callable[..., int | None]')
self.assertEqual(repr(Callable[[], List[T]][int]).replace('typing.', ''),
'Callable[[], List[int]]')
@@ -5079,9 +5071,9 @@ class GenericTests(BaseTestCase):
with self.assertRaises(TypeError):
issubclass(Tuple[int, ...], typing.Iterable)
- def test_fail_with_bare_union(self):
+ def test_fail_with_special_forms(self):
with self.assertRaises(TypeError):
- List[Union]
+ List[Final]
with self.assertRaises(TypeError):
Tuple[Optional]
with self.assertRaises(TypeError):
@@ -5623,8 +5615,6 @@ class GenericTests(BaseTestCase):
for obj in (
ClassVar[int],
Final[int],
- Union[int, float],
- Optional[int],
Literal[1, 2],
Concatenate[int, ParamSpec("P")],
TypeGuard[int],
@@ -5656,7 +5646,7 @@ class GenericTests(BaseTestCase):
__parameters__ = (T,)
# Bare classes should be skipped
for a in (List, list):
- for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, types.UnionType):
+ for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, Union):
with self.subTest(generic=a, sub=b):
with self.assertRaisesRegex(TypeError, '.* is not a generic class'):
a[b][str]
@@ -5675,7 +5665,7 @@ class GenericTests(BaseTestCase):
for s in (int, G, A, List, list,
TypeVar, TypeVarTuple, ParamSpec,
- types.GenericAlias, types.UnionType):
+ types.GenericAlias, Union):
for t in Tuple, tuple:
with self.subTest(tuple=t, sub=s):
@@ -7176,7 +7166,7 @@ class GetUtilitiesTestCase(TestCase):
self.assertIs(get_origin(Callable), collections.abc.Callable)
self.assertIs(get_origin(list[int]), list)
self.assertIs(get_origin(list), None)
- self.assertIs(get_origin(list | str), types.UnionType)
+ self.assertIs(get_origin(list | str), Union)
self.assertIs(get_origin(P.args), P)
self.assertIs(get_origin(P.kwargs), P)
self.assertIs(get_origin(Required[int]), Required)
@@ -10434,7 +10424,6 @@ class SpecialAttrsTests(BaseTestCase):
typing.TypeGuard: 'TypeGuard',
typing.TypeIs: 'TypeIs',
typing.TypeVar: 'TypeVar',
- typing.Union: 'Union',
typing.Self: 'Self',
# Subscripted special forms
typing.Annotated[Any, "Annotation"]: 'Annotated',
@@ -10445,7 +10434,7 @@ class SpecialAttrsTests(BaseTestCase):
typing.Literal[Any]: 'Literal',
typing.Literal[1, 2]: 'Literal',
typing.Literal[True, 2]: 'Literal',
- typing.Optional[Any]: 'Optional',
+ typing.Optional[Any]: 'Union',
typing.TypeGuard[Any]: 'TypeGuard',
typing.TypeIs[Any]: 'TypeIs',
typing.Union[Any]: 'Any',
@@ -10464,7 +10453,10 @@ class SpecialAttrsTests(BaseTestCase):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
s = pickle.dumps(cls, proto)
loaded = pickle.loads(s)
- self.assertIs(cls, loaded)
+ if isinstance(cls, Union):
+ self.assertEqual(cls, loaded)
+ else:
+ self.assertIs(cls, loaded)
TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any)
@@ -10739,6 +10731,34 @@ class TypeIterationTests(BaseTestCase):
self.assertNotIsInstance(type_to_test, collections.abc.Iterable)
+class UnionGenericAliasTests(BaseTestCase):
+ def test_constructor(self):
+ # Used e.g. in typer, pydantic
+ with self.assertWarns(DeprecationWarning):
+ inst = typing._UnionGenericAlias(typing.Union, (int, str))
+ self.assertEqual(inst, int | str)
+ with self.assertWarns(DeprecationWarning):
+ # name is accepted but ignored
+ inst = typing._UnionGenericAlias(typing.Union, (int, None), name="Optional")
+ self.assertEqual(inst, int | None)
+
+ def test_isinstance(self):
+ # Used e.g. in pydantic
+ with self.assertWarns(DeprecationWarning):
+ self.assertTrue(isinstance(Union[int, str], typing._UnionGenericAlias))
+ with self.assertWarns(DeprecationWarning):
+ self.assertFalse(isinstance(int, typing._UnionGenericAlias))
+
+ def test_eq(self):
+ # type(t) == _UnionGenericAlias is used in vyos
+ with self.assertWarns(DeprecationWarning):
+ self.assertEqual(Union, typing._UnionGenericAlias)
+ with self.assertWarns(DeprecationWarning):
+ self.assertEqual(typing._UnionGenericAlias, typing._UnionGenericAlias)
+ with self.assertWarns(DeprecationWarning):
+ self.assertNotEqual(int, typing._UnionGenericAlias)
+
+
def load_tests(loader, tests, pattern):
import doctest
tests.addTests(doctest.DocTestSuite(typing))
diff --git a/Lib/typing.py b/Lib/typing.py
index 66570db7a5b..4b3c63b25ae 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -29,7 +29,13 @@ import functools
import operator
import sys
import types
-from types import GenericAlias
+from types import (
+ WrapperDescriptorType,
+ MethodWrapperType,
+ MethodDescriptorType,
+ GenericAlias,
+)
+import warnings
from _typing import (
_idfunc,
@@ -40,6 +46,7 @@ from _typing import (
ParamSpecKwargs,
TypeAliasType,
Generic,
+ Union,
NoDefault,
)
@@ -367,21 +374,6 @@ def _compare_args_orderless(first_args, second_args):
return False
return not t
-def _remove_dups_flatten(parameters):
- """Internal helper for Union creation and substitution.
-
- Flatten Unions among parameters, then remove duplicates.
- """
- # Flatten out Union[Union[...], ...].
- params = []
- for p in parameters:
- if isinstance(p, (_UnionGenericAlias, types.UnionType)):
- params.extend(p.__args__)
- else:
- params.append(p)
-
- return tuple(_deduplicate(params, unhashable_fallback=True))
-
def _flatten_literal_params(parameters):
"""Internal helper for Literal creation: flatten Literals among parameters."""
@@ -470,7 +462,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
return evaluate_forward_ref(t, globals=globalns, locals=localns,
type_params=type_params, owner=owner,
_recursive_guard=recursive_guard, format=format)
- if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
+ if isinstance(t, (_GenericAlias, GenericAlias, Union)):
if isinstance(t, GenericAlias):
args = tuple(
_make_forward_ref(arg) if isinstance(arg, str) else arg
@@ -495,7 +487,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
return t
if isinstance(t, GenericAlias):
return GenericAlias(t.__origin__, ev_args)
- if isinstance(t, types.UnionType):
+ if isinstance(t, Union):
return functools.reduce(operator.or_, ev_args)
else:
return t.copy_with(ev_args)
@@ -750,59 +742,6 @@ def Final(self, parameters):
return _GenericAlias(self, (item,))
@_SpecialForm
-def Union(self, parameters):
- """Union type; Union[X, Y] means either X or Y.
-
- On Python 3.10 and higher, the | operator
- can also be used to denote unions;
- X | Y means the same thing to the type checker as Union[X, Y].
-
- To define a union, use e.g. Union[int, str]. Details:
- - The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
- - Unions of unions are flattened, e.g.::
-
- assert Union[Union[int, str], float] == Union[int, str, float]
-
- - Unions of a single argument vanish, e.g.::
-
- assert Union[int] == int # The constructor actually returns int
-
- - Redundant arguments are skipped, e.g.::
-
- assert Union[int, str, int] == Union[int, str]
-
- - When comparing unions, the argument order is ignored, e.g.::
-
- assert Union[int, str] == Union[str, int]
-
- - You cannot subclass or instantiate a union.
- - You can use Optional[X] as a shorthand for Union[X, None].
- """
- if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- msg = "Union[arg, ...]: each arg must be a type."
- parameters = tuple(_type_check(p, msg) for p in parameters)
- parameters = _remove_dups_flatten(parameters)
- if len(parameters) == 1:
- return parameters[0]
- if len(parameters) == 2 and type(None) in parameters:
- return _UnionGenericAlias(self, parameters, name="Optional")
- return _UnionGenericAlias(self, parameters)
-
-def _make_union(left, right):
- """Used from the C implementation of TypeVar.
-
- TypeVar.__or__ calls this instead of returning types.UnionType
- because we want to allow unions between TypeVars and strings
- (forward references).
- """
- return Union[left, right]
-
-@_SpecialForm
def Optional(self, parameters):
"""Optional[X] is equivalent to Union[X, None]."""
arg = _type_check(parameters, f"{self} requires a single type.")
@@ -1708,45 +1647,34 @@ class _TupleType(_SpecialGenericAlias, _root=True):
return self.copy_with(params)
-class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
- def copy_with(self, params):
- return Union[params]
+class _UnionGenericAliasMeta(type):
+ def __instancecheck__(self, inst: object) -> bool:
+ warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
+ return isinstance(inst, Union)
- def __eq__(self, other):
- if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
- return NotImplemented
- try: # fast path
- return set(self.__args__) == set(other.__args__)
- except TypeError: # not hashable, slow path
- return _compare_args_orderless(self.__args__, other.__args__)
+ def __subclasscheck__(self, inst: type) -> bool:
+ warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
+ return issubclass(inst, Union)
- def __hash__(self):
- return hash(frozenset(self.__args__))
+ def __eq__(self, other):
+ warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
+ if other is _UnionGenericAlias or other is Union:
+ return True
+ return NotImplemented
- def __repr__(self):
- args = self.__args__
- if len(args) == 2:
- if args[0] is type(None):
- return f'typing.Optional[{_type_repr(args[1])}]'
- elif args[1] is type(None):
- return f'typing.Optional[{_type_repr(args[0])}]'
- return super().__repr__()
- def __instancecheck__(self, obj):
- for arg in self.__args__:
- if isinstance(obj, arg):
- return True
- return False
+class _UnionGenericAlias(metaclass=_UnionGenericAliasMeta):
+ """Compatibility hack.
- def __subclasscheck__(self, cls):
- for arg in self.__args__:
- if issubclass(cls, arg):
- return True
- return False
+ A class named _UnionGenericAlias used to be used to implement
+ typing.Union. This class exists to serve as a shim to preserve
+ the meaning of some code that used to use _UnionGenericAlias
+ directly.
- def __reduce__(self):
- func, (origin, args) = super().__reduce__()
- return func, (Union, args)
+ """
+ def __new__(cls, self_cls, parameters, /, *, name=None):
+ warnings._deprecated("_UnionGenericAlias", remove=(3, 17))
+ return Union[parameters]
def _value_and_type_iter(parameters):
@@ -2472,7 +2400,7 @@ def _strip_annotations(t):
if stripped_args == t.__args__:
return t
return GenericAlias(t.__origin__, stripped_args)
- if isinstance(t, types.UnionType):
+ if isinstance(t, Union):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
@@ -2506,8 +2434,8 @@ def get_origin(tp):
return tp.__origin__
if tp is Generic:
return Generic
- if isinstance(tp, types.UnionType):
- return types.UnionType
+ if isinstance(tp, Union):
+ return Union
return None
@@ -2532,7 +2460,7 @@ def get_args(tp):
if _should_unflatten_callable_args(tp, res):
res = (list(res[:-1]), res[-1])
return res
- if isinstance(tp, types.UnionType):
+ if isinstance(tp, Union):
return tp.__args__
return ()
diff --git a/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst b/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst
new file mode 100644
index 00000000000..5240f4aa7d1
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst
@@ -0,0 +1,3 @@
+Make :class:`types.UnionType` an alias for :class:`typing.Union`. Both
+``int | str`` and ``Union[int, str]`` now create instances of the same
+type. Patch by Jelle Zijlstra.
diff --git a/Modules/_typingmodule.c b/Modules/_typingmodule.c
index 09fbb3c5e8b..e51279c808a 100644
--- a/Modules/_typingmodule.c
+++ b/Modules/_typingmodule.c
@@ -5,9 +5,10 @@
#endif
#include "Python.h"
-#include "pycore_interp.h"
+#include "internal/pycore_interp.h"
+#include "internal/pycore_typevarobject.h"
+#include "internal/pycore_unionobject.h" // _PyUnion_Type
#include "pycore_pystate.h" // _PyInterpreterState_GET()
-#include "pycore_typevarobject.h"
#include "clinic/_typingmodule.c.h"
/*[clinic input]
@@ -63,6 +64,9 @@ _typing_exec(PyObject *m)
if (PyModule_AddObjectRef(m, "TypeAliasType", (PyObject *)&_PyTypeAlias_Type) < 0) {
return -1;
}
+ if (PyModule_AddObjectRef(m, "Union", (PyObject *)&_PyUnion_Type) < 0) {
+ return -1;
+ }
if (PyModule_AddObjectRef(m, "NoDefault", (PyObject *)&_Py_NoDefaultStruct) < 0) {
return -1;
}
diff --git a/Objects/typevarobject.c b/Objects/typevarobject.c
index 3ab8cb14686..ace079dfef1 100644
--- a/Objects/typevarobject.c
+++ b/Objects/typevarobject.c
@@ -2,8 +2,8 @@
#include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK, PyAnnotateFormat
#include "pycore_typevarobject.h"
-#include "pycore_unionobject.h" // _Py_union_type_or
-
+#include "pycore_unionobject.h" // _Py_union_type_or, _Py_union_from_tuple
+#include "structmember.h"
/*[clinic input]
class typevar "typevarobject *" "&_PyTypeVar_Type"
@@ -370,9 +370,13 @@ type_check(PyObject *arg, const char *msg)
static PyObject *
make_union(PyObject *self, PyObject *other)
{
- PyObject *args[2] = {self, other};
- PyObject *result = call_typing_func_object("_make_union", args, 2);
- return result;
+ PyObject *args = PyTuple_Pack(2, self, other);
+ if (args == NULL) {
+ return NULL;
+ }
+ PyObject *u = _Py_union_from_tuple(args);
+ Py_DECREF(args);
+ return u;
}
static PyObject *
diff --git a/Objects/unionobject.c b/Objects/unionobject.c
index 6e65a653a95..065b0b85397 100644
--- a/Objects/unionobject.c
+++ b/Objects/unionobject.c
@@ -1,17 +1,17 @@
-// types.UnionType -- used to represent e.g. Union[int, str], int | str
+// typing.Union -- used to represent e.g. Union[int, str], int | str
#include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
#include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr
#include "pycore_unionobject.h"
-static PyObject *make_union(PyObject *);
-
-
typedef struct {
PyObject_HEAD
- PyObject *args;
+ PyObject *args; // all args (tuple)
+ PyObject *hashable_args; // frozenset or NULL
+ PyObject *unhashable_args; // tuple or NULL
PyObject *parameters;
+ PyObject *weakreflist;
} unionobject;
static void
@@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self)
unionobject *alias = (unionobject *)self;
_PyObject_GC_UNTRACK(self);
+ if (alias->weakreflist != NULL) {
+ PyObject_ClearWeakRefs((PyObject *)alias);
+ }
Py_XDECREF(alias->args);
+ Py_XDECREF(alias->hashable_args);
+ Py_XDECREF(alias->unhashable_args);
Py_XDECREF(alias->parameters);
Py_TYPE(self)->tp_free(self);
}
@@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
{
unionobject *alias = (unionobject *)self;
Py_VISIT(alias->args);
+ Py_VISIT(alias->hashable_args);
+ Py_VISIT(alias->unhashable_args);
Py_VISIT(alias->parameters);
return 0;
}
@@ -39,13 +46,67 @@ static Py_hash_t
union_hash(PyObject *self)
{
unionobject *alias = (unionobject *)self;
- PyObject *args = PyFrozenSet_New(alias->args);
- if (args == NULL) {
- return (Py_hash_t)-1;
+ // If there are any unhashable args, treat this union as unhashable.
+ // Otherwise, two unions might compare equal but have different hashes.
+ if (alias->unhashable_args) {
+ // Attempt to get an error from one of the values.
+ assert(PyTuple_CheckExact(alias->unhashable_args));
+ Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
+ for (Py_ssize_t i = 0; i < n; i++) {
+ PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
+ Py_hash_t hash = PyObject_Hash(arg);
+ if (hash == -1) {
+ return -1;
+ }
+ }
+ // The unhashable values somehow became hashable again. Still raise
+ // an error.
+ PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n);
+ return -1;
}
- Py_hash_t hash = PyObject_Hash(args);
- Py_DECREF(args);
- return hash;
+ return PyObject_Hash(alias->hashable_args);
+}
+
+static int
+unions_equal(unionobject *a, unionobject *b)
+{
+ int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
+ if (result == -1) {
+ return -1;
+ }
+ if (result == 0) {
+ return 0;
+ }
+ if (a->unhashable_args && b->unhashable_args) {
+ Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
+ if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
+ return 0;
+ }
+ for (Py_ssize_t i = 0; i < n; i++) {
+ PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
+ int result = PySequence_Contains(b->unhashable_args, arg_a);
+ if (result == -1) {
+ return -1;
+ }
+ if (!result) {
+ return 0;
+ }
+ }
+ for (Py_ssize_t i = 0; i < n; i++) {
+ PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
+ int result = PySequence_Contains(a->unhashable_args, arg_b);
+ if (result == -1) {
+ return -1;
+ }
+ if (!result) {
+ return 0;
+ }
+ }
+ }
+ else if (a->unhashable_args || b->unhashable_args) {
+ return 0;
+ }
+ return 1;
}
static PyObject *
@@ -55,93 +116,128 @@ union_richcompare(PyObject *a, PyObject *b, int op)
Py_RETURN_NOTIMPLEMENTED;
}
- PyObject *a_set = PySet_New(((unionobject*)a)->args);
- if (a_set == NULL) {
+ int equal = unions_equal((unionobject*)a, (unionobject*)b);
+ if (equal == -1) {
return NULL;
}
- PyObject *b_set = PySet_New(((unionobject*)b)->args);
- if (b_set == NULL) {
- Py_DECREF(a_set);
- return NULL;
+ if (op == Py_EQ) {
+ return PyBool_FromLong(equal);
+ }
+ else {
+ return PyBool_FromLong(!equal);
}
- PyObject *result = PyObject_RichCompare(a_set, b_set, op);
- Py_DECREF(b_set);
- Py_DECREF(a_set);
- return result;
}
-static int
-is_same(PyObject *left, PyObject *right)
+typedef struct {
+ PyObject *args; // list
+ PyObject *hashable_args; // set
+ PyObject *unhashable_args; // list or NULL
+ bool is_checked; // whether to call type_check()
+} unionbuilder;
+
+static bool unionbuilder_add_tuple(unionbuilder *, PyObject *);
+static PyObject *make_union(unionbuilder *);
+static PyObject *type_check(PyObject *, const char *);
+
+static bool
+unionbuilder_init(unionbuilder *ub, bool is_checked)
{
- int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
- return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
+ ub->args = PyList_New(0);
+ if (ub->args == NULL) {
+ return false;
+ }
+ ub->hashable_args = PySet_New(NULL);
+ if (ub->hashable_args == NULL) {
+ Py_DECREF(ub->args);
+ return false;
+ }
+ ub->unhashable_args = NULL;
+ ub->is_checked = is_checked;
+ return true;
}
-static int
-contains(PyObject **items, Py_ssize_t size, PyObject *obj)
+static void
+unionbuilder_finalize(unionbuilder *ub)
{
- for (Py_ssize_t i = 0; i < size; i++) {
- int is_duplicate = is_same(items[i], obj);
- if (is_duplicate) { // -1 or 1
- return is_duplicate;
- }
- }
- return 0;
+ Py_DECREF(ub->args);
+ Py_DECREF(ub->hashable_args);
+ Py_XDECREF(ub->unhashable_args);
}
-static PyObject *
-merge(PyObject **items1, Py_ssize_t size1,
- PyObject **items2, Py_ssize_t size2)
+static bool
+unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
{
- PyObject *tuple = NULL;
- Py_ssize_t pos = 0;
-
- for (Py_ssize_t i = 0; i < size2; i++) {
- PyObject *arg = items2[i];
- int is_duplicate = contains(items1, size1, arg);
- if (is_duplicate < 0) {
- Py_XDECREF(tuple);
- return NULL;
- }
- if (is_duplicate) {
- continue;
+ Py_hash_t hash = PyObject_Hash(arg);
+ if (hash == -1) {
+ PyErr_Clear();
+ if (ub->unhashable_args == NULL) {
+ ub->unhashable_args = PyList_New(0);
+ if (ub->unhashable_args == NULL) {
+ return false;
+ }
}
-
- if (tuple == NULL) {
- tuple = PyTuple_New(size1 + size2 - i);
- if (tuple == NULL) {
- return NULL;
+ else {
+ int contains = PySequence_Contains(ub->unhashable_args, arg);
+ if (contains < 0) {
+ return false;
}
- for (; pos < size1; pos++) {
- PyObject *a = items1[pos];
- PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
+ if (contains == 1) {
+ return true;
}
}
- PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
- pos++;
+ if (PyList_Append(ub->unhashable_args, arg) < 0) {
+ return false;
+ }
}
-
- if (tuple) {
- (void) _PyTuple_Resize(&tuple, pos);
+ else {
+ int contains = PySet_Contains(ub->hashable_args, arg);
+ if (contains < 0) {
+ return false;
+ }
+ if (contains == 1) {
+ return true;
+ }
+ if (PySet_Add(ub->hashable_args, arg) < 0) {
+ return false;
+ }
}
- return tuple;
+ return PyList_Append(ub->args, arg) == 0;
}
-static PyObject **
-get_types(PyObject **obj, Py_ssize_t *size)
+static bool
+unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
{
- if (*obj == Py_None) {
- *obj = (PyObject *)&_PyNone_Type;
+ if (Py_IsNone(arg)) {
+ arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed
}
- if (_PyUnion_Check(*obj)) {
- PyObject *args = ((unionobject *) *obj)->args;
- *size = PyTuple_GET_SIZE(args);
- return &PyTuple_GET_ITEM(args, 0);
+ else if (_PyUnion_Check(arg)) {
+ PyObject *args = ((unionobject *)arg)->args;
+ return unionbuilder_add_tuple(ub, args);
+ }
+ if (ub->is_checked) {
+ PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
+ if (type == NULL) {
+ return false;
+ }
+ bool result = unionbuilder_add_single_unchecked(ub, type);
+ Py_DECREF(type);
+ return result;
}
else {
- *size = 1;
- return obj;
+ return unionbuilder_add_single_unchecked(ub, arg);
+ }
+}
+
+static bool
+unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
+{
+ Py_ssize_t n = PyTuple_GET_SIZE(tuple);
+ for (Py_ssize_t i = 0; i < n; i++) {
+ if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
+ return false;
+ }
}
+ return true;
}
static int
@@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other)
Py_RETURN_NOTIMPLEMENTED;
}
- Py_ssize_t size1, size2;
- PyObject **items1 = get_types(&self, &size1);
- PyObject **items2 = get_types(&other, &size2);
- PyObject *tuple = merge(items1, size1, items2, size2);
- if (tuple == NULL) {
- if (PyErr_Occurred()) {
- return NULL;
- }
- return Py_NewRef(self);
+ unionbuilder ub;
+ // unchecked because we already checked is_unionable()
+ if (!unionbuilder_init(&ub, false)) {
+ return NULL;
+ }
+ if (!unionbuilder_add_single(&ub, self) ||
+ !unionbuilder_add_single(&ub, other)) {
+ unionbuilder_finalize(&ub);
+ return NULL;
}
- PyObject *new_union = make_union(tuple);
- Py_DECREF(tuple);
+ PyObject *new_union = make_union(&ub);
return new_union;
}
@@ -202,6 +297,18 @@ union_repr(PyObject *self)
goto error;
}
}
+
+#if 0
+ PyUnicodeWriter_WriteUTF8(writer, "|args=", 6);
+ PyUnicodeWriter_WriteRepr(writer, alias->args);
+ PyUnicodeWriter_WriteUTF8(writer, "|h=", 3);
+ PyUnicodeWriter_WriteRepr(writer, alias->hashable_args);
+ if (alias->unhashable_args) {
+ PyUnicodeWriter_WriteUTF8(writer, "|u=", 3);
+ PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args);
+ }
+#endif
+
return PyUnicodeWriter_Finish(writer);
error:
@@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item)
return NULL;
}
- PyObject *res;
- Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
- if (nargs == 0) {
- res = make_union(newargs);
- }
- else {
- res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
- for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
- PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
- Py_SETREF(res, PyNumber_Or(res, arg));
- if (res == NULL) {
- break;
- }
- }
- }
+ PyObject *res = _Py_union_from_tuple(newargs);
Py_DECREF(newargs);
return res;
}
@@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused))
return Py_NewRef(alias->parameters);
}
+static PyObject *
+union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
+{
+ return PyUnicode_FromString("Union");
+}
+
+static PyObject *
+union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
+{
+ return Py_NewRef(&_PyUnion_Type);
+}
+
static PyGetSetDef union_properties[] = {
+ {"__name__", union_name, NULL,
+ PyDoc_STR("Name of the type"), NULL},
+ {"__qualname__", union_name, NULL,
+ PyDoc_STR("Qualified name of the type"), NULL},
+ {"__origin__", union_origin, NULL,
+ PyDoc_STR("Always returns the type"), NULL},
{"__parameters__", union_parameters, (setter)NULL,
PyDoc_STR("Type variables in the types.UnionType."), NULL},
{0}
@@ -306,10 +417,88 @@ _Py_union_args(PyObject *self)
return ((unionobject *) self)->args;
}
+static PyObject *
+call_typing_func_object(const char *name, PyObject **args, size_t nargs)
+{
+ PyObject *typing = PyImport_ImportModule("typing");
+ if (typing == NULL) {
+ return NULL;
+ }
+ PyObject *func = PyObject_GetAttrString(typing, name);
+ if (func == NULL) {
+ Py_DECREF(typing);
+ return NULL;
+ }
+ PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL);
+ Py_DECREF(func);
+ Py_DECREF(typing);
+ return result;
+}
+
+static PyObject *
+type_check(PyObject *arg, const char *msg)
+{
+ if (Py_IsNone(arg)) {
+ // NoneType is immortal, so don't need an INCREF
+ return (PyObject *)Py_TYPE(arg);
+ }
+ // Fast path to avoid calling into typing.py
+ if (is_unionable(arg)) {
+ return Py_NewRef(arg);
+ }
+ PyObject *message_str = PyUnicode_FromString(msg);
+ if (message_str == NULL) {
+ return NULL;
+ }
+ PyObject *args[2] = {arg, message_str};
+ PyObject *result = call_typing_func_object("_type_check", args, 2);
+ Py_DECREF(message_str);
+ return result;
+}
+
+PyObject *
+_Py_union_from_tuple(PyObject *args)
+{
+ unionbuilder ub;
+ if (!unionbuilder_init(&ub, true)) {
+ return NULL;
+ }
+ if (PyTuple_CheckExact(args)) {
+ if (!unionbuilder_add_tuple(&ub, args)) {
+ return NULL;
+ }
+ }
+ else {
+ if (!unionbuilder_add_single(&ub, args)) {
+ return NULL;
+ }
+ }
+ return make_union(&ub);
+}
+
+static PyObject *
+union_class_getitem(PyObject *cls, PyObject *args)
+{
+ return _Py_union_from_tuple(args);
+}
+
+static PyObject *
+union_mro_entries(PyObject *self, PyObject *args)
+{
+ return PyErr_Format(PyExc_TypeError,
+ "Cannot subclass %R", self);
+}
+
+static PyMethodDef union_methods[] = {
+ {"__mro_entries__", union_mro_entries, METH_O},
+ {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
+ {0}
+};
+
PyTypeObject _PyUnion_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- .tp_name = "types.UnionType",
- .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
+ .tp_name = "typing.Union",
+ .tp_doc = PyDoc_STR("Represent a union type\n"
"\n"
"E.g. for int | str"),
.tp_basicsize = sizeof(unionobject),
@@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = {
.tp_hash = union_hash,
.tp_getattro = union_getattro,
.tp_members = union_members,
+ .tp_methods = union_methods,
.tp_richcompare = union_richcompare,
.tp_as_mapping = &union_as_mapping,
.tp_as_number = &union_as_number,
.tp_repr = union_repr,
.tp_getset = union_properties,
+ .tp_weaklistoffset = offsetof(unionobject, weakreflist),
};
static PyObject *
-make_union(PyObject *args)
+make_union(unionbuilder *ub)
{
- assert(PyTuple_CheckExact(args));
+ Py_ssize_t n = PyList_GET_SIZE(ub->args);
+ if (n == 0) {
+ PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
+ unionbuilder_finalize(ub);
+ return NULL;
+ }
+ if (n == 1) {
+ PyObject *result = PyList_GET_ITEM(ub->args, 0);
+ Py_INCREF(result);
+ unionbuilder_finalize(ub);
+ return result;
+ }
+
+ PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
+ args = PyList_AsTuple(ub->args);
+ if (args == NULL) {
+ goto error;
+ }
+ hashable_args = PyFrozenSet_New(ub->hashable_args);
+ if (hashable_args == NULL) {
+ goto error;
+ }
+ if (ub->unhashable_args != NULL) {
+ unhashable_args = PyList_AsTuple(ub->unhashable_args);
+ if (unhashable_args == NULL) {
+ goto error;
+ }
+ }
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
if (result == NULL) {
- return NULL;
+ goto error;
}
+ unionbuilder_finalize(ub);
result->parameters = NULL;
- result->args = Py_NewRef(args);
+ result->args = args;
+ result->hashable_args = hashable_args;
+ result->unhashable_args = unhashable_args;
+ result->weakreflist = NULL;
_PyObject_GC_TRACK(result);
return (PyObject*)result;
+error:
+ Py_XDECREF(args);
+ Py_XDECREF(hashable_args);
+ Py_XDECREF(unhashable_args);
+ unionbuilder_finalize(ub);
+ return NULL;
}