aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Objects/unionobject.c
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2025-03-04 11:44:19 -0800
committerGitHub <noreply@github.com>2025-03-04 11:44:19 -0800
commitdc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada (patch)
tree045fed4b7965d56ea45c009dad6dddb42d7be8b0 /Objects/unionobject.c
parente091520fdbcfe406e5fdcf66b7864b2b34a6726b (diff)
downloadcpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.tar.gz
cpython-dc6d66f44c0a25b69dfec7e4ffc4a6fa5e4feada.zip
gh-105499: Merge typing.Union and types.UnionType (#105511)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com> Co-authored-by: Ken Jin <kenjin@python.org> Co-authored-by: Carl Meyer <carl@oddbird.net>
Diffstat (limited to 'Objects/unionobject.c')
-rw-r--r--Objects/unionobject.c438
1 files changed, 333 insertions, 105 deletions
diff --git a/Objects/unionobject.c b/Objects/unionobject.c
index 6e65a653a95..065b0b85397 100644
--- a/Objects/unionobject.c
+++ b/Objects/unionobject.c
@@ -1,17 +1,17 @@
-// types.UnionType -- used to represent e.g. Union[int, str], int | str
+// typing.Union -- used to represent e.g. Union[int, str], int | str
#include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
#include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr
#include "pycore_unionobject.h"
-static PyObject *make_union(PyObject *);
-
-
typedef struct {
PyObject_HEAD
- PyObject *args;
+ PyObject *args; // all args (tuple)
+ PyObject *hashable_args; // frozenset or NULL
+ PyObject *unhashable_args; // tuple or NULL
PyObject *parameters;
+ PyObject *weakreflist; // weak reference list; cleared in unionobject_dealloc, located via tp_weaklistoffset
} unionobject;
static void
@@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self)
unionobject *alias = (unionobject *)self;
_PyObject_GC_UNTRACK(self);
+ if (alias->weakreflist != NULL) {
+ PyObject_ClearWeakRefs((PyObject *)alias);
+ }
Py_XDECREF(alias->args);
+ Py_XDECREF(alias->hashable_args);
+ Py_XDECREF(alias->unhashable_args);
Py_XDECREF(alias->parameters);
Py_TYPE(self)->tp_free(self);
}
@@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
{
unionobject *alias = (unionobject *)self;
Py_VISIT(alias->args);
+ Py_VISIT(alias->hashable_args);
+ Py_VISIT(alias->unhashable_args);
Py_VISIT(alias->parameters);
return 0;
}
@@ -39,13 +46,67 @@ static Py_hash_t
union_hash(PyObject *self)
{
unionobject *alias = (unionobject *)self;
- PyObject *args = PyFrozenSet_New(alias->args);
- if (args == NULL) {
- return (Py_hash_t)-1;
+ // Hash of a union: the hash of its frozenset of hashable members.
+ // Returns -1 with an exception set when the union cannot be hashed.
+ //
+ // If there are any unhashable args, treat this union as unhashable.
+ // Otherwise, two unions might compare equal but have different hashes.
+ if (alias->unhashable_args) {
+ // Attempt to get an error from one of the values.
+ assert(PyTuple_CheckExact(alias->unhashable_args));
+ Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
+ for (Py_ssize_t i = 0; i < n; i++) {
+ PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
+ Py_hash_t hash = PyObject_Hash(arg);
+ if (hash == -1) {
+ return -1;
+ }
+ }
+ // The unhashable values somehow became hashable again. Still raise
+ // an error.
+ // n is a Py_ssize_t, so it must be formatted with %zd: %d makes the
+ // formatter read an int from the varargs, which is a type mismatch.
+ PyErr_Format(PyExc_TypeError, "union contains %zd unhashable elements", n);
+ return -1;
}
- Py_hash_t hash = PyObject_Hash(args);
- Py_DECREF(args);
- return hash;
+ return PyObject_Hash(alias->hashable_args);
+}
+
+// Order-insensitive equality of two unions.
+// The hashable members are compared as sets; the unhashable members must
+// have the same count and each side's members must all be found in the other.
+// Returns 1 if equal, 0 if not equal, -1 with an exception set on error.
+static int
+unions_equal(unionobject *a, unionobject *b)
+{
+    int same_hashable = PyObject_RichCompareBool(a->hashable_args,
+                                                 b->hashable_args, Py_EQ);
+    if (same_hashable == -1 || same_hashable == 0) {
+        // -1: the comparison raised; 0: the hashable members differ.
+        return same_hashable;
+    }
+
+    PyObject *left = a->unhashable_args;
+    PyObject *right = b->unhashable_args;
+    if (left == NULL || right == NULL) {
+        // Equal only when neither union has unhashable members.
+        return left == NULL && right == NULL;
+    }
+
+    Py_ssize_t count = PyTuple_GET_SIZE(left);
+    if (count != PyTuple_GET_SIZE(right)) {
+        return 0;
+    }
+
+    // First pass: every member of a is in b. Second pass: the reverse.
+    for (int pass = 0; pass < 2; pass++) {
+        PyObject *needles = (pass == 0) ? left : right;
+        PyObject *haystack = (pass == 0) ? right : left;
+        for (Py_ssize_t idx = 0; idx < count; idx++) {
+            PyObject *needle = PyTuple_GET_ITEM(needles, idx);
+            int found = PySequence_Contains(haystack, needle);
+            if (found == -1) {
+                return -1;
+            }
+            if (!found) {
+                return 0;
+            }
+        }
+    }
+    return 1;
 }
static PyObject *
@@ -55,93 +116,128 @@ union_richcompare(PyObject *a, PyObject *b, int op)
Py_RETURN_NOTIMPLEMENTED;
}
- PyObject *a_set = PySet_New(((unionobject*)a)->args);
- if (a_set == NULL) {
+ int equal = unions_equal((unionobject*)a, (unionobject*)b);
+ if (equal == -1) {
return NULL;
}
- PyObject *b_set = PySet_New(((unionobject*)b)->args);
- if (b_set == NULL) {
- Py_DECREF(a_set);
- return NULL;
+ if (op == Py_EQ) {
+ return PyBool_FromLong(equal);
+ }
+ else {
+ return PyBool_FromLong(!equal);
}
- PyObject *result = PyObject_RichCompare(a_set, b_set, op);
- Py_DECREF(b_set);
- Py_DECREF(a_set);
- return result;
}
-static int
-is_same(PyObject *left, PyObject *right)
+typedef struct { // working state while collecting union members; set up by unionbuilder_init, released by unionbuilder_finalize
+ PyObject *args; // list
+ PyObject *hashable_args; // set
+ PyObject *unhashable_args; // list or NULL
+ bool is_checked; // whether to call type_check()
+} unionbuilder;
+
+static bool unionbuilder_add_tuple(unionbuilder *, PyObject *);
+static PyObject *make_union(unionbuilder *);
+static PyObject *type_check(PyObject *, const char *);
+
+static bool
+unionbuilder_init(unionbuilder *ub, bool is_checked) // prepares an empty builder; returns false if either container cannot be created
{
- int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
- return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
+ ub->args = PyList_New(0); // members in insertion order
+ if (ub->args == NULL) {
+ return false;
+ }
+ ub->hashable_args = PySet_New(NULL); // empty set used to detect duplicate hashable members
+ if (ub->hashable_args == NULL) {
+ Py_DECREF(ub->args); // undo the first allocation so a failed init leaves nothing owned
+ return false;
+ }
+ ub->unhashable_args = NULL; // created lazily by unionbuilder_add_single_unchecked
+ ub->is_checked = is_checked;
+ return true;
}
-static int
-contains(PyObject **items, Py_ssize_t size, PyObject *obj)
+static void
+unionbuilder_finalize(unionbuilder *ub) // releases the builder's references; expects a successfully initialized builder
{
- for (Py_ssize_t i = 0; i < size; i++) {
- int is_duplicate = is_same(items[i], obj);
- if (is_duplicate) { // -1 or 1
- return is_duplicate;
- }
- }
- return 0;
+ Py_DECREF(ub->args);
+ Py_DECREF(ub->hashable_args);
+ Py_XDECREF(ub->unhashable_args); // may still be NULL if no unhashable member was added
}
-static PyObject *
-merge(PyObject **items1, Py_ssize_t size1,
- PyObject **items2, Py_ssize_t size2)
+static bool
+unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg) // adds arg unless it is already present; returns false on error
{
- PyObject *tuple = NULL;
- Py_ssize_t pos = 0;
-
- for (Py_ssize_t i = 0; i < size2; i++) {
- PyObject *arg = items2[i];
- int is_duplicate = contains(items1, size1, arg);
- if (is_duplicate < 0) {
- Py_XDECREF(tuple);
- return NULL;
- }
- if (is_duplicate) {
- continue;
+ Py_hash_t hash = PyObject_Hash(arg);
+ if (hash == -1) {
+ PyErr_Clear(); // any hashing error is discarded and arg is tracked as unhashable; NOTE(review): this also hides non-TypeError failures, confirm that is intended
+ if (ub->unhashable_args == NULL) {
+ ub->unhashable_args = PyList_New(0); // first unhashable member: create the list on demand
+ if (ub->unhashable_args == NULL) {
+ return false;
+ }
}
-
- if (tuple == NULL) {
- tuple = PyTuple_New(size1 + size2 - i);
- if (tuple == NULL) {
- return NULL;
+ else {
+ int contains = PySequence_Contains(ub->unhashable_args, arg); // duplicate check against the unhashable members seen so far
+ if (contains < 0) {
+ return false;
}
- for (; pos < size1; pos++) {
- PyObject *a = items1[pos];
- PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
+ if (contains == 1) {
+ return true; // duplicate: nothing to add, not an error
}
}
- PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
- pos++;
+ if (PyList_Append(ub->unhashable_args, arg) < 0) {
+ return false;
+ }
}
-
- if (tuple) {
- (void) _PyTuple_Resize(&tuple, pos);
+ else {
+ int contains = PySet_Contains(ub->hashable_args, arg); // duplicate check through the set
+ if (contains < 0) {
+ return false;
+ }
+ if (contains == 1) {
+ return true; // duplicate: nothing to add, not an error
+ }
+ if (PySet_Add(ub->hashable_args, arg) < 0) {
+ return false;
+ }
}
- return tuple;
+ return PyList_Append(ub->args, arg) == 0; // new member: record it in insertion order
}
-static PyObject **
-get_types(PyObject **obj, Py_ssize_t *size)
+static bool
+unionbuilder_add_single(unionbuilder *ub, PyObject *arg) // adds one argument, flattening nested unions; returns false on error
{
- if (*obj == Py_None) {
- *obj = (PyObject *)&_PyNone_Type;
+ if (Py_IsNone(arg)) {
+ arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed
}
- if (_PyUnion_Check(*obj)) {
- PyObject *args = ((unionobject *) *obj)->args;
- *size = PyTuple_GET_SIZE(args);
- return &PyTuple_GET_ITEM(args, 0);
+ else if (_PyUnion_Check(arg)) {
+ PyObject *args = ((unionobject *)arg)->args; // borrowed reference to the nested union's member tuple
+ return unionbuilder_add_tuple(ub, args); // add each member individually instead of nesting the union
+ }
+ if (ub->is_checked) {
+ PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
+ if (type == NULL) {
+ return false;
+ }
+ bool result = unionbuilder_add_single_unchecked(ub, type); // the value returned by type_check is what gets stored
+ Py_DECREF(type);
+ return result;
}
else {
- *size = 1;
- return obj;
+ return unionbuilder_add_single_unchecked(ub, arg);
+ }
+}
+
+// Feed every item of tuple to unionbuilder_add_single(), in order.
+// Stops at the first failure and returns false; returns true otherwise.
+static bool
+unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
+{
+    const Py_ssize_t count = PyTuple_GET_SIZE(tuple);
+    Py_ssize_t idx = 0;
+    while (idx < count) {
+        PyObject *member = PyTuple_GET_ITEM(tuple, idx);
+        if (!unionbuilder_add_single(ub, member)) {
+            return false;
+        }
+        idx++;
     }
+    return true;
 }
static int
@@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other)
Py_RETURN_NOTIMPLEMENTED;
}
- Py_ssize_t size1, size2;
- PyObject **items1 = get_types(&self, &size1);
- PyObject **items2 = get_types(&other, &size2);
- PyObject *tuple = merge(items1, size1, items2, size2);
- if (tuple == NULL) {
- if (PyErr_Occurred()) {
- return NULL;
- }
- return Py_NewRef(self);
+ unionbuilder ub;
+ // unchecked because we already checked is_unionable()
+ if (!unionbuilder_init(&ub, false)) {
+ return NULL;
+ }
+ if (!unionbuilder_add_single(&ub, self) ||
+ !unionbuilder_add_single(&ub, other)) {
+ unionbuilder_finalize(&ub);
+ return NULL;
}
- PyObject *new_union = make_union(tuple);
- Py_DECREF(tuple);
+ PyObject *new_union = make_union(&ub);
return new_union;
}
@@ -202,6 +297,18 @@ union_repr(PyObject *self)
goto error;
}
}
+
+#if 0
+ PyUnicodeWriter_WriteUTF8(writer, "|args=", 6);
+ PyUnicodeWriter_WriteRepr(writer, alias->args);
+ PyUnicodeWriter_WriteUTF8(writer, "|h=", 3);
+ PyUnicodeWriter_WriteRepr(writer, alias->hashable_args);
+ if (alias->unhashable_args) {
+ PyUnicodeWriter_WriteUTF8(writer, "|u=", 3);
+ PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args);
+ }
+#endif
+
return PyUnicodeWriter_Finish(writer);
error:
@@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item)
return NULL;
}
- PyObject *res;
- Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
- if (nargs == 0) {
- res = make_union(newargs);
- }
- else {
- res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
- for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
- PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
- Py_SETREF(res, PyNumber_Or(res, arg));
- if (res == NULL) {
- break;
- }
- }
- }
+ PyObject *res = _Py_union_from_tuple(newargs);
Py_DECREF(newargs);
return res;
}
@@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused))
return Py_NewRef(alias->parameters);
}
+// Getter that always produces the string "Union"; both arguments are unused.
+static PyObject *
+union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
+{
+    static const char fixed_name[] = "Union";
+    return PyUnicode_FromString(fixed_name);
+}
+
+// Getter that returns a new reference to the union type object itself;
+// both arguments are unused.
+static PyObject *
+union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
+{
+    PyObject *origin = (PyObject *)&_PyUnion_Type;
+    Py_INCREF(origin);
+    return origin;
+}
+
static PyGetSetDef union_properties[] = {
+ {"__name__", union_name, NULL,
+ PyDoc_STR("Name of the type"), NULL},
+ {"__qualname__", union_name, NULL,
+ PyDoc_STR("Qualified name of the type"), NULL},
+ {"__origin__", union_origin, NULL,
+ PyDoc_STR("Always returns the type"), NULL},
{"__parameters__", union_parameters, (setter)NULL,
PyDoc_STR("Type variables in the types.UnionType."), NULL},
{0}
@@ -306,10 +417,88 @@ _Py_union_args(PyObject *self)
return ((unionobject *) self)->args;
}
+// Import the "typing" module, fetch its attribute `name` and call it with
+// the nargs positional arguments stored in args.
+// Returns the call result, or NULL with an exception set when the import,
+// the attribute lookup or the call itself fails.
+static PyObject *
+call_typing_func_object(const char *name, PyObject **args, size_t nargs)
+{
+    PyObject *outcome = NULL;
+    PyObject *typing_mod = PyImport_ImportModule("typing");
+    if (typing_mod != NULL) {
+        PyObject *callee = PyObject_GetAttrString(typing_mod, name);
+        if (callee != NULL) {
+            outcome = PyObject_Vectorcall(callee, args, nargs, NULL);
+            Py_DECREF(callee);
+        }
+        Py_DECREF(typing_mod);
+    }
+    return outcome;
+}
+
+// Validate a prospective union member.
+// None maps to its type; an object accepted by is_unionable() is returned
+// unchanged as a new reference; anything else is handed, together with msg,
+// to typing._type_check() and that call's result is returned.
+// Returns NULL with an exception set on failure.
+static PyObject *
+type_check(PyObject *arg, const char *msg)
+{
+    if (Py_IsNone(arg)) {
+        // NoneType is immortal, so don't need an INCREF
+        return (PyObject *)Py_TYPE(arg);
+    }
+    if (!is_unionable(arg)) {
+        // Slow path: defer the decision to typing.py
+        PyObject *msg_obj = PyUnicode_FromString(msg);
+        if (msg_obj == NULL) {
+            return NULL;
+        }
+        PyObject *call_args[2] = {arg, msg_obj};
+        PyObject *checked = call_typing_func_object("_type_check", call_args, 2);
+        Py_DECREF(msg_obj);
+        return checked;
+    }
+    // Fast path to avoid calling into typing.py
+    return Py_NewRef(arg);
+}
+
+// Build a union from a subscript argument, type-checking every member.
+// args is either an exact tuple of members or a single member.
+// Returns a new reference (the sole member itself when only one distinct
+// member remains, see make_union), or NULL with an exception set.
+PyObject *
+_Py_union_from_tuple(PyObject *args)
+{
+    unionbuilder ub;
+    if (!unionbuilder_init(&ub, true)) {
+        return NULL;
+    }
+    bool added;
+    if (PyTuple_CheckExact(args)) {
+        added = unionbuilder_add_tuple(&ub, args);
+    }
+    else {
+        added = unionbuilder_add_single(&ub, args);
+    }
+    if (!added) {
+        // The builder still owns its list/set here; release them so a
+        // failed add does not leak (same as _Py_union_type_or does).
+        unionbuilder_finalize(&ub);
+        return NULL;
+    }
+    // make_union() finalizes the builder on every path, success or error.
+    return make_union(&ub);
+}
+
+// __class_getitem__ implementation: builds the union described by args.
+// cls is not consulted.
+static PyObject *
+union_class_getitem(PyObject *cls, PyObject *args)
+{
+    PyObject *made = _Py_union_from_tuple(args);
+    return made;
+}
+
+// __mro_entries__ implementation: always fails with TypeError, naming the
+// union (via its repr) in the message. args is ignored.
+static PyObject *
+union_mro_entries(PyObject *self, PyObject *args)
+{
+    PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
+    return NULL;
+}
+
+static PyMethodDef union_methods[] = { // method table referenced by _PyUnion_Type.tp_methods
+ {"__mro_entries__", union_mro_entries, METH_O}, // always raises TypeError ("Cannot subclass ...")
+ {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")}, // class method taking the subscript argument
+ {0} // sentinel terminating the table
+};
+
PyTypeObject _PyUnion_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- .tp_name = "types.UnionType",
- .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
+ .tp_name = "typing.Union",
+ .tp_doc = PyDoc_STR("Represent a union type\n"
"\n"
"E.g. for int | str"),
.tp_basicsize = sizeof(unionobject),
@@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = {
.tp_hash = union_hash,
.tp_getattro = union_getattro,
.tp_members = union_members,
+ .tp_methods = union_methods,
.tp_richcompare = union_richcompare,
.tp_as_mapping = &union_as_mapping,
.tp_as_number = &union_as_number,
.tp_repr = union_repr,
.tp_getset = union_properties,
+ .tp_weaklistoffset = offsetof(unionobject, weakreflist),
};
static PyObject *
-make_union(PyObject *args)
+make_union(unionbuilder *ub) // turns the builder's contents into the final object; finalizes ub on every path
{
- assert(PyTuple_CheckExact(args));
+ Py_ssize_t n = PyList_GET_SIZE(ub->args);
+ if (n == 0) {
+ PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
+ unionbuilder_finalize(ub);
+ return NULL;
+ }
+ if (n == 1) { // a single distinct member is returned as itself, not wrapped in a union
+ PyObject *result = PyList_GET_ITEM(ub->args, 0);
+ Py_INCREF(result); // take our own reference before the builder's list is released
+ unionbuilder_finalize(ub);
+ return result;
+ }
+
+ PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
+ args = PyList_AsTuple(ub->args); // immutable copies of the builder's containers
+ if (args == NULL) {
+ goto error;
+ }
+ hashable_args = PyFrozenSet_New(ub->hashable_args);
+ if (hashable_args == NULL) {
+ goto error;
+ }
+ if (ub->unhashable_args != NULL) { // stays NULL on the result when every member was hashable
+ unhashable_args = PyList_AsTuple(ub->unhashable_args);
+ if (unhashable_args == NULL) {
+ goto error;
+ }
+ }
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
if (result == NULL) {
- return NULL;
+ goto error;
}
+ unionbuilder_finalize(ub); // builder no longer needed; the copies made above are handed to result
result->parameters = NULL;
- result->args = Py_NewRef(args);
+ result->args = args; // ownership of the three references moves into the object, no extra INCREF
+ result->hashable_args = hashable_args;
+ result->unhashable_args = unhashable_args;
+ result->weakreflist = NULL;
_PyObject_GC_TRACK(result);
return (PyObject*)result;
+error: // shared failure exit: drop whatever copies exist, then release the builder
+ Py_XDECREF(args);
+ Py_XDECREF(hashable_args);
+ Py_XDECREF(unhashable_args);
+ unionbuilder_finalize(ub);
+ return NULL;
}