summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objset.c6
-rw-r--r--tests/basics/tests/set_binop.py51
2 files changed, 31 insertions, 26 deletions
diff --git a/py/objset.c b/py/objset.c
index 2ed2abb611..e41f2c47f4 100644
--- a/py/objset.c
+++ b/py/objset.c
@@ -27,6 +27,10 @@ static mp_obj_t set_it_iternext(mp_obj_t self_in);
void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in) {
mp_obj_set_t *self = self_in;
+ if (self->set.used == 0) {
+ print(env, "set()");
+ return;
+ }
bool first = true;
print(env, "{");
for (int i = 0; i < self->set.alloc; i++) {
@@ -122,7 +126,7 @@ static mp_obj_t set_copy(mp_obj_t self_in) {
mp_obj_set_t *other = m_new_obj(mp_obj_set_t);
other->base.type = &set_type;
- mp_set_init(&other->set, self->set.alloc);
+ mp_set_init(&other->set, self->set.alloc - 1);
other->set.used = self->set.used;
memcpy(other->set.table, self->set.table, self->set.alloc * sizeof(mp_obj_t));
diff --git a/tests/basics/tests/set_binop.py b/tests/basics/tests/set_binop.py
index 46ecbcb63e..d0d0b8027b 100644
--- a/tests/basics/tests/set_binop.py
+++ b/tests/basics/tests/set_binop.py
@@ -1,29 +1,30 @@
def r(s):
l = list(s)
l.sort()
- print(l)
-s = {1, 2}
-t = {2, 3}
-r(s | t)
-r(s ^ t)
-r(s & t)
-r(s - t)
-u = s.copy()
-u |= t
-r(u)
-u = s.copy()
-u ^= t
-r(u)
-u = s.copy()
-u &= t
-r(u)
-u = s.copy()
-u -= t
-r(u)
+ return l
+sets = [set(), {1}, {1, 2}, {1, 2, 3}, {2, 3}, {2, 3, 5}, {5}, {7}]
+for s in sets:
+ for t in sets:
+ print(s, '|', t, '=', r(s | t))
+ print(s, '^', t, '=', r(s ^ t))
+ print(s, '&', t, '=', r(s & t))
+ print(s, '-', t, '=', r(s - t))
+ u = s.copy()
+ u |= t
+ print(s, "|=", t, '-->', r(u))
+ u = s.copy()
+ u ^= t
+ print(s, "^=", t, '-->', r(u))
+ u = s.copy()
+ u &= t
+ print(s, "&=", t, "-->", r(u))
+ u = s.copy()
+ u -= t
+ print(s, "-=", t, "-->", r(u))
-print(s == t)
-print(s != t)
-print(s > t)
-print(s >= t)
-print(s < t)
-print(s <= t)
+ print(s, '==', t, '=', s == t)
+ print(s, '!=', t, '=', s != t)
+ print(s, '>', t, '=', s > t)
+ print(s, '>=', t, '=', s >= t)
+ print(s, '<', t, '=', s < t)
+ print(s, '<=', t, '=', s <= t)