diff options
author | Jelle Zijlstra <jelle.zijlstra@gmail.com> | 2025-03-04 11:44:19 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-04 11:44:19 -0800 |
commit | dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada (patch) | |
tree | 045fed4b7965d56ea45c009dad6dddb42d7be8b0 /Objects/unionobject.c | |
parent | e091520fdbcfe406e5fdcf66b7864b2b34a6726b (diff) | |
download | cpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.tar.gz cpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.zip |
gh-105499: Merge typing.Union and types.UnionType (#105511)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ken Jin <kenjin@python.org>
Co-authored-by: Carl Meyer <carl@oddbird.net>
Diffstat (limited to 'Objects/unionobject.c')
-rw-r--r-- | Objects/unionobject.c | 438 |
1 files changed, 333 insertions, 105 deletions
diff --git a/Objects/unionobject.c b/Objects/unionobject.c index 6e65a653a95..065b0b85397 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -1,17 +1,17 @@ -// types.UnionType -- used to represent e.g. Union[int, str], int | str +// typing.Union -- used to represent e.g. Union[int, str], int | str #include "Python.h" #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK #include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr #include "pycore_unionobject.h" -static PyObject *make_union(PyObject *); - - typedef struct { PyObject_HEAD - PyObject *args; + PyObject *args; // all args (tuple) + PyObject *hashable_args; // frozenset or NULL + PyObject *unhashable_args; // tuple or NULL PyObject *parameters; + PyObject *weakreflist; } unionobject; static void @@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self) unionobject *alias = (unionobject *)self; _PyObject_GC_UNTRACK(self); + if (alias->weakreflist != NULL) { + PyObject_ClearWeakRefs((PyObject *)alias); + } Py_XDECREF(alias->args); + Py_XDECREF(alias->hashable_args); + Py_XDECREF(alias->unhashable_args); Py_XDECREF(alias->parameters); Py_TYPE(self)->tp_free(self); } @@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg) { unionobject *alias = (unionobject *)self; Py_VISIT(alias->args); + Py_VISIT(alias->hashable_args); + Py_VISIT(alias->unhashable_args); Py_VISIT(alias->parameters); return 0; } @@ -39,13 +46,67 @@ static Py_hash_t union_hash(PyObject *self) { unionobject *alias = (unionobject *)self; - PyObject *args = PyFrozenSet_New(alias->args); - if (args == NULL) { - return (Py_hash_t)-1; + // If there are any unhashable args, treat this union as unhashable. + // Otherwise, two unions might compare equal but have different hashes. + if (alias->unhashable_args) { + // Attempt to get an error from one of the values. + assert(PyTuple_CheckExact(alias->unhashable_args)); + Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args); + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i); + Py_hash_t hash = PyObject_Hash(arg); + if (hash == -1) { + return -1; + } + } + // The unhashable values somehow became hashable again. Still raise + // an error. + PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n); + return -1; } - Py_hash_t hash = PyObject_Hash(args); - Py_DECREF(args); - return hash; + return PyObject_Hash(alias->hashable_args); +} + +static int +unions_equal(unionobject *a, unionobject *b) +{ + int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ); + if (result == -1) { + return -1; + } + if (result == 0) { + return 0; + } + if (a->unhashable_args && b->unhashable_args) { + Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args); + if (n != PyTuple_GET_SIZE(b->unhashable_args)) { + return 0; + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i); + int result = PySequence_Contains(b->unhashable_args, arg_a); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i); + int result = PySequence_Contains(a->unhashable_args, arg_b); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + } + else if (a->unhashable_args || b->unhashable_args) { + return 0; + } + return 1; } static PyObject * @@ -55,93 +116,128 @@ union_richcompare(PyObject *a, PyObject *b, int op) Py_RETURN_NOTIMPLEMENTED; } - PyObject *a_set = PySet_New(((unionobject*)a)->args); - if (a_set == NULL) { + int equal = unions_equal((unionobject*)a, (unionobject*)b); + if (equal == -1) { return NULL; } - PyObject *b_set = PySet_New(((unionobject*)b)->args); - if (b_set == NULL) { - Py_DECREF(a_set); - return NULL; + if (op == Py_EQ) { + return PyBool_FromLong(equal); + } + else { + return PyBool_FromLong(!equal); } - PyObject *result = PyObject_RichCompare(a_set, b_set, op); - Py_DECREF(b_set); - Py_DECREF(a_set); - return result; } -static int -is_same(PyObject *left, PyObject *right) +typedef struct { + PyObject *args; // list + PyObject *hashable_args; // set + PyObject *unhashable_args; // list or NULL + bool is_checked; // whether to call type_check() +} unionbuilder; + +static bool unionbuilder_add_tuple(unionbuilder *, PyObject *); +static PyObject *make_union(unionbuilder *); +static PyObject *type_check(PyObject *, const char *); + +static bool +unionbuilder_init(unionbuilder *ub, bool is_checked) { - int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right); - return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right; + ub->args = PyList_New(0); + if (ub->args == NULL) { + return false; + } + ub->hashable_args = PySet_New(NULL); + if (ub->hashable_args == NULL) { + Py_DECREF(ub->args); + return false; + } + ub->unhashable_args = NULL; + ub->is_checked = is_checked; + return true; } -static int -contains(PyObject **items, Py_ssize_t size, PyObject *obj) +static void +unionbuilder_finalize(unionbuilder *ub) { - for (Py_ssize_t i = 0; i < size; i++) { - int is_duplicate = is_same(items[i], obj); - if (is_duplicate) { // -1 or 1 - return is_duplicate; - } - } - return 0; + Py_DECREF(ub->args); + Py_DECREF(ub->hashable_args); + Py_XDECREF(ub->unhashable_args); } -static PyObject * -merge(PyObject **items1, Py_ssize_t size1, - PyObject **items2, Py_ssize_t size2) +static bool +unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg) { - PyObject *tuple = NULL; - Py_ssize_t pos = 0; - - for (Py_ssize_t i = 0; i < size2; i++) { - PyObject *arg = items2[i]; - int is_duplicate = contains(items1, size1, arg); - if (is_duplicate < 0) { - Py_XDECREF(tuple); - return NULL; - } - if (is_duplicate) { - continue; + Py_hash_t hash = PyObject_Hash(arg); + if (hash == -1) { + PyErr_Clear(); + if (ub->unhashable_args == NULL) { + ub->unhashable_args = PyList_New(0); + if (ub->unhashable_args == NULL) { + return false; + } } - - if (tuple == NULL) { - tuple = PyTuple_New(size1 + size2 - i); - if (tuple == NULL) { - return NULL; + else { + int contains = PySequence_Contains(ub->unhashable_args, arg); + if (contains < 0) { + return false; } - for (; pos < size1; pos++) { - PyObject *a = items1[pos]; - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a)); + if (contains == 1) { + return true; } } - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg)); - pos++; + if (PyList_Append(ub->unhashable_args, arg) < 0) { + return false; + } } - - if (tuple) { - (void) _PyTuple_Resize(&tuple, pos); + else { + int contains = PySet_Contains(ub->hashable_args, arg); + if (contains < 0) { + return false; + } + if (contains == 1) { + return true; + } + if (PySet_Add(ub->hashable_args, arg) < 0) { + return false; + } } - return tuple; + return PyList_Append(ub->args, arg) == 0; } -static PyObject ** -get_types(PyObject **obj, Py_ssize_t *size) +static bool +unionbuilder_add_single(unionbuilder *ub, PyObject *arg) { - if (*obj == Py_None) { - *obj = (PyObject *)&_PyNone_Type; + if (Py_IsNone(arg)) { + arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed } - if (_PyUnion_Check(*obj)) { - PyObject *args = ((unionobject *) *obj)->args; - *size = PyTuple_GET_SIZE(args); - return &PyTuple_GET_ITEM(args, 0); + else if (_PyUnion_Check(arg)) { + PyObject *args = ((unionobject *)arg)->args; + return unionbuilder_add_tuple(ub, args); + } + if (ub->is_checked) { + PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type."); + if (type == NULL) { + return false; + } + bool result = unionbuilder_add_single_unchecked(ub, type); + Py_DECREF(type); + return result; } else { - *size = 1; - return obj; + return unionbuilder_add_single_unchecked(ub, arg); + } +} + +static bool +unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple) +{ + Py_ssize_t n = PyTuple_GET_SIZE(tuple); + for (Py_ssize_t i = 0; i < n; i++) { + if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) { + return false; + } } + return true; } static int @@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other) Py_RETURN_NOTIMPLEMENTED; } - Py_ssize_t size1, size2; - PyObject **items1 = get_types(&self, &size1); - PyObject **items2 = get_types(&other, &size2); - PyObject *tuple = merge(items1, size1, items2, size2); - if (tuple == NULL) { - if (PyErr_Occurred()) { - return NULL; - } - return Py_NewRef(self); + unionbuilder ub; + // unchecked because we already checked is_unionable() + if (!unionbuilder_init(&ub, false)) { + return NULL; + } + if (!unionbuilder_add_single(&ub, self) || + !unionbuilder_add_single(&ub, other)) { + unionbuilder_finalize(&ub); + return NULL; } - PyObject *new_union = make_union(tuple); - Py_DECREF(tuple); + PyObject *new_union = make_union(&ub); return new_union; } @@ -202,6 +297,18 @@ union_repr(PyObject *self) goto error; } } + +#if 0 + PyUnicodeWriter_WriteUTF8(writer, "|args=", 6); + PyUnicodeWriter_WriteRepr(writer, alias->args); + PyUnicodeWriter_WriteUTF8(writer, "|h=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->hashable_args); + if (alias->unhashable_args) { + PyUnicodeWriter_WriteUTF8(writer, "|u=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args); + } +#endif + return PyUnicodeWriter_Finish(writer); error: @@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } - PyObject *res; - Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); - if (nargs == 0) { - res = make_union(newargs); - } - else { - res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0)); - for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { - PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); - Py_SETREF(res, PyNumber_Or(res, arg)); - if (res == NULL) { - break; - } - } - } + PyObject *res = _Py_union_from_tuple(newargs); Py_DECREF(newargs); return res; } @@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused)) return Py_NewRef(alias->parameters); } +static PyObject * +union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) +{ + return PyUnicode_FromString("Union"); +} + +static PyObject * +union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) +{ + return Py_NewRef(&_PyUnion_Type); +} + static PyGetSetDef union_properties[] = { + {"__name__", union_name, NULL, + PyDoc_STR("Name of the type"), NULL}, + {"__qualname__", union_name, NULL, + PyDoc_STR("Qualified name of the type"), NULL}, + {"__origin__", union_origin, NULL, + PyDoc_STR("Always returns the type"), NULL}, {"__parameters__", union_parameters, (setter)NULL, PyDoc_STR("Type variables in the types.UnionType."), NULL}, {0} @@ -306,10 +417,88 @@ _Py_union_args(PyObject *self) return ((unionobject *) self)->args; } +static PyObject * +call_typing_func_object(const char *name, PyObject **args, size_t nargs) +{ + PyObject *typing = PyImport_ImportModule("typing"); + if (typing == NULL) { + return NULL; + } + PyObject *func = PyObject_GetAttrString(typing, name); + if (func == NULL) { + Py_DECREF(typing); + return NULL; + } + PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL); + Py_DECREF(func); + Py_DECREF(typing); + return result; +} + +static PyObject * +type_check(PyObject *arg, const char *msg) +{ + if (Py_IsNone(arg)) { + // NoneType is immortal, so don't need an INCREF + return (PyObject *)Py_TYPE(arg); + } + // Fast path to avoid calling into typing.py + if (is_unionable(arg)) { + return Py_NewRef(arg); + } + PyObject *message_str = PyUnicode_FromString(msg); + if (message_str == NULL) { + return NULL; + } + PyObject *args[2] = {arg, message_str}; + PyObject *result = call_typing_func_object("_type_check", args, 2); + Py_DECREF(message_str); + return result; +} + +PyObject * +_Py_union_from_tuple(PyObject *args) +{ + unionbuilder ub; + if (!unionbuilder_init(&ub, true)) { + return NULL; + } + if (PyTuple_CheckExact(args)) { + if (!unionbuilder_add_tuple(&ub, args)) { + return NULL; + } + } + else { + if (!unionbuilder_add_single(&ub, args)) { + return NULL; + } + } + return make_union(&ub); +} + +static PyObject * +union_class_getitem(PyObject *cls, PyObject *args) +{ + return _Py_union_from_tuple(args); +} + +static PyObject * +union_mro_entries(PyObject *self, PyObject *args) +{ + return PyErr_Format(PyExc_TypeError, + "Cannot subclass %R", self); +} + +static PyMethodDef union_methods[] = { + {"__mro_entries__", union_mro_entries, METH_O}, + {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")}, + {0} +}; + PyTypeObject _PyUnion_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - .tp_name = "types.UnionType", - .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n" + .tp_name = "typing.Union", + .tp_doc = PyDoc_STR("Represent a union type\n" "\n" "E.g. for int | str"), .tp_basicsize = sizeof(unionobject), @@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = { .tp_hash = union_hash, .tp_getattro = union_getattro, .tp_members = union_members, + .tp_methods = union_methods, .tp_richcompare = union_richcompare, .tp_as_mapping = &union_as_mapping, .tp_as_number = &union_as_number, .tp_repr = union_repr, .tp_getset = union_properties, + .tp_weaklistoffset = offsetof(unionobject, weakreflist), }; static PyObject * -make_union(PyObject *args) +make_union(unionbuilder *ub) { - assert(PyTuple_CheckExact(args)); + Py_ssize_t n = PyList_GET_SIZE(ub->args); + if (n == 0) { + PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types."); + unionbuilder_finalize(ub); + return NULL; + } + if (n == 1) { + PyObject *result = PyList_GET_ITEM(ub->args, 0); + Py_INCREF(result); + unionbuilder_finalize(ub); + return result; + } + + PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL; + args = PyList_AsTuple(ub->args); + if (args == NULL) { + goto error; + } + hashable_args = PyFrozenSet_New(ub->hashable_args); + if (hashable_args == NULL) { + goto error; + } + if (ub->unhashable_args != NULL) { + unhashable_args = PyList_AsTuple(ub->unhashable_args); + if (unhashable_args == NULL) { + goto error; + } + } unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { - return NULL; + goto error; } + unionbuilder_finalize(ub); result->parameters = NULL; - result->args = Py_NewRef(args); + result->args = args; + result->hashable_args = hashable_args; + result->unhashable_args = unhashable_args; + result->weakreflist = NULL; _PyObject_GC_TRACK(result); return (PyObject*)result; +error: + Py_XDECREF(args); + Py_XDECREF(hashable_args); + Py_XDECREF(unhashable_args); + unionbuilder_finalize(ub); + return NULL; } |