summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/obj.c26
-rw-r--r--tests/basics/builtin_hash.py13
2 files changed, 33 insertions, 6 deletions
diff --git a/py/obj.c b/py/obj.c
index 81adbe3933..1b42377ed8 100644
--- a/py/obj.c
+++ b/py/obj.c
@@ -166,20 +166,34 @@ mp_int_t mp_obj_hash(mp_obj_t o_in) {
return mp_obj_tuple_hash(o_in);
} else if (MP_OBJ_IS_TYPE(o_in, &mp_type_type)) {
return (mp_int_t)o_in;
- } else if (MP_OBJ_IS_OBJ(o_in)) {
+ } else if (mp_obj_is_instance_type(mp_obj_get_type(o_in))) {
// if a valid __hash__ method exists, use it
- mp_obj_t hash_method[2];
- mp_load_method_maybe(o_in, MP_QSTR___hash__, hash_method);
- if (hash_method[0] != MP_OBJ_NULL) {
- mp_obj_t hash_val = mp_call_method_n_kw(0, 0, hash_method);
+ mp_obj_t method[2];
+ mp_load_method_maybe(o_in, MP_QSTR___hash__, method);
+ if (method[0] != MP_OBJ_NULL) {
+ mp_obj_t hash_val = mp_call_method_n_kw(0, 0, method);
if (MP_OBJ_IS_INT(hash_val)) {
return mp_obj_int_get_truncated(hash_val);
}
+ goto error;
}
+
+ mp_load_method_maybe(o_in, MP_QSTR___eq__, method);
+ if (method[0] == MP_OBJ_NULL) {
+ // https://docs.python.org/3/reference/datamodel.html#object.__hash__
+ // "User-defined classes have __eq__() and __hash__() methods by default;
+ // with them, all objects compare unequal (except with themselves) and
+ // x.__hash__() returns an appropriate value such that x == y implies
+ // both that x is y and hash(x) == hash(y)."
+ return (mp_int_t)o_in;
+ }
+ // "A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
+ // When the __hash__() method of a class is None, instances of the class will raise an appropriate TypeError"
}
- // TODO hash class and instances - in CPython by default user created classes' __hash__ resolves to their id
+ // TODO hash classes
+error:
if (MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE) {
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "unhashable type"));
} else {
diff --git a/tests/basics/builtin_hash.py b/tests/basics/builtin_hash.py
index 0abfe980e1..d7615c3ec0 100644
--- a/tests/basics/builtin_hash.py
+++ b/tests/basics/builtin_hash.py
@@ -19,3 +19,16 @@ class A:
print(hash(A()))
print({A():1})
+
+class B:
+ pass
+hash(B())
+
+
+class C:
+ def __eq__(self, another):
+ return True
+try:
+ hash(C())
+except TypeError:
+ print("TypeError")