summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/obj.h1
-rw-r--r--py/objtuple.c14
-rw-r--r--py/objtype.c9
-rw-r--r--tests/basics/subclass_native_cmp.py9
4 files changed, 30 insertions, 3 deletions
diff --git a/py/obj.h b/py/obj.h
index 5757810c96..30a60b77d0 100644
--- a/py/obj.h
+++ b/py/obj.h
@@ -401,6 +401,7 @@ mp_obj_t mp_obj_new_module(qstr module_name);
mp_obj_type_t *mp_obj_get_type(mp_const_obj_t o_in);
const char *mp_obj_get_type_str(mp_const_obj_t o_in);
bool mp_obj_is_subclass_fast(mp_const_obj_t object, mp_const_obj_t classinfo); // arguments should be type objects
+mp_obj_t mp_instance_cast_to_native_base(mp_const_obj_t self_in, mp_const_obj_t native_type);
void mp_obj_print_helper(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t o_in, mp_print_kind_t kind);
void mp_obj_print(mp_obj_t o, mp_print_kind_t kind);
diff --git a/py/objtuple.c b/py/objtuple.c
index 7d4e87755f..ca65b28e31 100644
--- a/py/objtuple.c
+++ b/py/objtuple.c
@@ -99,12 +99,20 @@ mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m
// Don't pass MP_BINARY_OP_NOT_EQUAL here
STATIC bool tuple_cmp_helper(int op, mp_obj_t self_in, mp_obj_t another_in) {
- assert(MP_OBJ_IS_TYPE(self_in, &mp_type_tuple));
- if (!MP_OBJ_IS_TYPE(another_in, &mp_type_tuple)) {
- return false;
+ mp_obj_type_t *self_type = mp_obj_get_type(self_in);
+ if (self_type->getiter != tuple_getiter) {
+ assert(0);
}
+ mp_obj_type_t *another_type = mp_obj_get_type(another_in);
mp_obj_tuple_t *self = self_in;
mp_obj_tuple_t *another = another_in;
+ if (another_type->getiter != tuple_getiter) {
+ // Slow path for user subclasses
+ another = mp_instance_cast_to_native_base(another, &mp_type_tuple);
+ if (another == MP_OBJ_NULL) {
+ return false;
+ }
+ }
return mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len);
}
diff --git a/py/objtype.c b/py/objtype.c
index ef5f6b9d9c..c579477db7 100644
--- a/py/objtype.c
+++ b/py/objtype.c
@@ -845,6 +845,15 @@ STATIC mp_obj_t mp_builtin_isinstance(mp_obj_t object, mp_obj_t classinfo) {
MP_DEFINE_CONST_FUN_OBJ_2(mp_builtin_isinstance_obj, mp_builtin_isinstance);
+mp_obj_t mp_instance_cast_to_native_base(mp_const_obj_t self_in, mp_const_obj_t native_type) {
+ mp_obj_type_t *self_type = mp_obj_get_type(self_in);
+ if (!mp_obj_is_subclass_fast(self_type, native_type)) {
+ return MP_OBJ_NULL;
+ }
+ mp_obj_instance_t *self = (mp_obj_instance_t*)self_in;
+ return self->subobj[0];
+}
+
/******************************************************************************/
// staticmethod and classmethod types (probably should go in a different file)
diff --git a/tests/basics/subclass_native_cmp.py b/tests/basics/subclass_native_cmp.py
new file mode 100644
index 0000000000..1a095bfa1a
--- /dev/null
+++ b/tests/basics/subclass_native_cmp.py
@@ -0,0 +1,9 @@
+# Test calling non-special method inherited from native type
+
+class mytuple(tuple):
+ pass
+
+t = mytuple((1, 2, 3))
+print(t)
+print(t == (1, 2, 3))
+print((1, 2, 3) == t)