summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objtype.c14
-rw-r--r--tests/basics/class_new.py12
2 files changed, 23 insertions, 3 deletions
diff --git a/py/objtype.c b/py/objtype.c
index f812a0e86c..7689e42b25 100644
--- a/py/objtype.c
+++ b/py/objtype.c
@@ -46,6 +46,8 @@
#define DEBUG_printf(...) (void)0
#endif
+STATIC mp_obj_t static_class_method_make_new(mp_obj_t self_in, uint n_args, uint n_kw, const mp_obj_t *args);
+
/******************************************************************************/
// instance object
@@ -749,6 +751,8 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict)
assert(MP_OBJ_IS_TYPE(bases_tuple, &mp_type_tuple)); // Micro Python restriction, for now
assert(MP_OBJ_IS_TYPE(locals_dict, &mp_type_dict)); // Micro Python restriction, for now
+ // TODO might need to make a copy of locals_dict; at least that's how CPython does it
+
// Basic validation of base classes
uint len;
mp_obj_t *items;
@@ -783,6 +787,16 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict)
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "multiple bases have instance lay-out conflict"));
}
+ mp_map_t *locals_map = mp_obj_dict_get_map(o->locals_dict);
+ mp_map_elem_t *elem = mp_map_lookup(locals_map, MP_OBJ_NEW_QSTR(MP_QSTR___new__), MP_MAP_LOOKUP);
+ if (elem != NULL) {
+ // __new__ slot exists; check if it is a function
+ if (MP_OBJ_IS_TYPE(elem->value, &mp_type_fun_native) || MP_OBJ_IS_TYPE(elem->value, &mp_type_fun_bc)) {
+ // __new__ is a function, wrap it in a staticmethod decorator
+ elem->value = static_class_method_make_new((mp_obj_t)&mp_type_staticmethod, 1, 0, &elem->value);
+ }
+ }
+
return o;
}
diff --git a/tests/basics/class_new.py b/tests/basics/class_new.py
index 7fedcab6c2..7e84dccf40 100644
--- a/tests/basics/class_new.py
+++ b/tests/basics/class_new.py
@@ -1,6 +1,4 @@
class A:
-
- @staticmethod
def __new__(cls):
print("A.__new__")
return super(cls, A).__new__(cls)
@@ -9,13 +7,21 @@ class A:
pass
def meth(self):
- pass
+ print('A.meth')
#print(A.__new__)
#print(A.__init__)
a = A()
+a.meth()
+
+a = A.__new__(A)
+a.meth()
#print(a.meth)
#print(a.__init__)
#print(a.__new__)
+
+# __new__ should automatically be a staticmethod, so this should work
+a = a.__new__(A)
+a.meth()