diff options
Diffstat (limited to 'py/objset.c')
-rw-r--r-- | py/objset.c | 69 |
1 files changed, 46 insertions, 23 deletions
diff --git a/py/objset.c b/py/objset.c index fc124fcd8c..f74bc74a07 100644 --- a/py/objset.c +++ b/py/objset.c @@ -44,7 +44,7 @@ typedef struct _mp_obj_set_it_t { mp_obj_base_t base; mp_fun_1_t iternext; mp_obj_set_t *set; - mp_uint_t cur; + size_t cur; } mp_obj_set_it_t; STATIC mp_obj_t set_it_iternext(mp_obj_t self_in); @@ -96,7 +96,7 @@ STATIC void set_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t } #endif mp_print_str(print, "{"); - for (mp_uint_t i = 0; i < self->set.alloc; i++) { + for (size_t i = 0; i < self->set.alloc; i++) { if (MP_SET_SLOT_IS_FILLED(&self->set, i)) { if (!first) { mp_print_str(print, ", "); @@ -129,7 +129,7 @@ STATIC mp_obj_t set_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_ default: { // can only be 0 or 1 arg // 1 argument, an iterable from which we make a new set mp_obj_t set = mp_obj_new_set(0, NULL); - mp_obj_t iterable = mp_getiter(args[0]); + mp_obj_t iterable = mp_getiter(args[0], NULL); mp_obj_t item; while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { mp_obj_set_store(set, item); @@ -143,10 +143,10 @@ STATIC mp_obj_t set_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_ STATIC mp_obj_t set_it_iternext(mp_obj_t self_in) { mp_obj_set_it_t *self = MP_OBJ_TO_PTR(self_in); - mp_uint_t max = self->set->set.alloc; + size_t max = self->set->set.alloc; mp_set_t *set = &self->set->set; - for (mp_uint_t i = self->cur; i < max; i++) { + for (size_t i = self->cur; i < max; i++) { if (MP_SET_SLOT_IS_FILLED(set, i)) { self->cur = i + 1; return set->table[i]; @@ -156,8 +156,9 @@ STATIC mp_obj_t set_it_iternext(mp_obj_t self_in) { return MP_OBJ_STOP_ITERATION; } -STATIC mp_obj_t set_getiter(mp_obj_t set_in) { - mp_obj_set_it_t *o = m_new_obj(mp_obj_set_it_t); +STATIC mp_obj_t set_getiter(mp_obj_t set_in, mp_obj_iter_buf_t *iter_buf) { + assert(sizeof(mp_obj_set_it_t) <= sizeof(mp_obj_iter_buf_t)); + mp_obj_set_it_t *o = (mp_obj_set_it_t*)iter_buf; o->base.type = &mp_type_polymorph_iter; o->iternext = set_it_iternext; o->set = (mp_obj_set_t *)MP_OBJ_TO_PTR(set_in); @@ -228,12 +229,12 @@ STATIC mp_obj_t set_diff_int(size_t n_args, const mp_obj_t *args, bool update) { } - for (mp_uint_t i = 1; i < n_args; i++) { + for (size_t i = 1; i < n_args; i++) { mp_obj_t other = args[i]; if (self == other) { set_clear(self); } else { - mp_obj_t iter = mp_getiter(other); + mp_obj_t iter = mp_getiter(other, NULL); mp_obj_t next; while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { set_discard(self, next); @@ -270,7 +271,7 @@ STATIC mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update) mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in); mp_obj_set_t *out = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL)); - mp_obj_t iter = mp_getiter(other); + mp_obj_t iter = mp_getiter(other, NULL); mp_obj_t next; while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { if (mp_set_lookup(&self->set, next, MP_MAP_LOOKUP)) { @@ -302,7 +303,8 @@ STATIC mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) { check_set_or_frozenset(self_in); mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in); - mp_obj_t iter = mp_getiter(other); + mp_obj_iter_buf_t iter_buf; + mp_obj_t iter = mp_getiter(other, &iter_buf); mp_obj_t next; while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { if (mp_set_lookup(&self->set, next, MP_MAP_LOOKUP)) { @@ -335,7 +337,8 @@ STATIC mp_obj_t set_issubset_internal(mp_obj_t self_in, mp_obj_t other_in, bool if (proper && self->set.used == other->set.used) { out = false; } else { - mp_obj_t iter = set_getiter(MP_OBJ_FROM_PTR(self)); + mp_obj_iter_buf_t iter_buf; + mp_obj_t iter = set_getiter(MP_OBJ_FROM_PTR(self), &iter_buf); mp_obj_t next; while ((next = set_it_iternext(iter)) != MP_OBJ_STOP_ITERATION) { if (!mp_set_lookup(&other->set, next, MP_MAP_LOOKUP)) { @@ -408,7 +411,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(set_remove_obj, set_remove); STATIC mp_obj_t set_symmetric_difference_update(mp_obj_t self_in, mp_obj_t other_in) { check_set(self_in); mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in); - mp_obj_t iter = mp_getiter(other_in); + mp_obj_t iter = mp_getiter(other_in, NULL); mp_obj_t next; while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { mp_set_lookup(&self->set, next, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND_OR_REMOVE_IF_FOUND); @@ -427,7 +430,7 @@ STATIC mp_obj_t set_symmetric_difference(mp_obj_t self_in, mp_obj_t other_in) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(set_symmetric_difference_obj, set_symmetric_difference); STATIC void set_update_int(mp_obj_set_t *self, mp_obj_t other_in) { - mp_obj_t iter = mp_getiter(other_in); + mp_obj_t iter = mp_getiter(other_in, NULL); mp_obj_t next; while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { mp_set_lookup(&self->set, next, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND); @@ -436,7 +439,7 @@ STATIC void set_update_int(mp_obj_set_t *self, mp_obj_t other_in) { STATIC mp_obj_t set_update(size_t n_args, const mp_obj_t *args) { check_set(args[0]); - for (mp_uint_t i = 1; i < n_args; i++) { + for (size_t i = 1; i < n_args; i++) { set_update_int(MP_OBJ_TO_PTR(args[0]), args[i]); } @@ -462,10 +465,10 @@ STATIC mp_obj_t set_unary_op(mp_uint_t op, mp_obj_t self_in) { if (MP_OBJ_IS_TYPE(self_in, &mp_type_frozenset)) { // start hash with unique value mp_int_t hash = (mp_int_t)(uintptr_t)&mp_type_frozenset; - mp_uint_t max = self->set.alloc; + size_t max = self->set.alloc; mp_set_t *set = &self->set; - for (mp_uint_t i = 0; i < max; i++) { + for (size_t i = 0; i < max; i++) { if (MP_SET_SLOT_IS_FILLED(set, i)) { hash += MP_OBJ_SMALL_INT_VALUE(mp_unary_op(MP_UNARY_OP_HASH, set->table[i])); } @@ -479,6 +482,11 @@ STATIC mp_obj_t set_unary_op(mp_uint_t op, mp_obj_t self_in) { STATIC mp_obj_t set_binary_op(mp_uint_t op, mp_obj_t lhs, mp_obj_t rhs) { mp_obj_t args[] = {lhs, rhs}; + #if MICROPY_PY_BUILTINS_FROZENSET + bool update = MP_OBJ_IS_TYPE(lhs, &mp_type_set); + #else + bool update = true; + #endif switch (op) { case MP_BINARY_OP_OR: return set_union(lhs, rhs); @@ -489,13 +497,28 @@ STATIC mp_obj_t set_binary_op(mp_uint_t op, mp_obj_t lhs, mp_obj_t rhs) { case MP_BINARY_OP_SUBTRACT: return set_diff(2, args); case MP_BINARY_OP_INPLACE_OR: - return set_union(lhs, rhs); + if (update) { + set_update(2, args); + return lhs; + } else { + return set_union(lhs, rhs); + } case MP_BINARY_OP_INPLACE_XOR: - return set_symmetric_difference(lhs, rhs); + if (update) { + set_symmetric_difference_update(lhs, rhs); + return lhs; + } else { + return set_symmetric_difference(lhs, rhs); + } case MP_BINARY_OP_INPLACE_AND: - return set_intersect(lhs, rhs); + rhs = set_intersect_int(lhs, rhs, update); + if (update) { + return lhs; + } else { + return rhs; + } case MP_BINARY_OP_INPLACE_SUBTRACT: - return set_diff(2, args); + return set_diff_int(2, args, update); case MP_BINARY_OP_LESS: return set_issubset_proper(lhs, rhs); case MP_BINARY_OP_MORE: @@ -567,11 +590,11 @@ const mp_obj_type_t mp_type_frozenset = { }; #endif -mp_obj_t mp_obj_new_set(mp_uint_t n_args, mp_obj_t *items) { +mp_obj_t mp_obj_new_set(size_t n_args, mp_obj_t *items) { mp_obj_set_t *o = m_new_obj(mp_obj_set_t); o->base.type = &mp_type_set; mp_set_init(&o->set, n_args); - for (mp_uint_t i = 0; i < n_args; i++) { + for (size_t i = 0; i < n_args; i++) { mp_set_lookup(&o->set, items[i], MP_MAP_LOOKUP_ADD_IF_NOT_FOUND); } return MP_OBJ_FROM_PTR(o); |