summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objlist.c64
-rw-r--r--tests/basics/tests/list_compare.py50
2 files changed, 114 insertions, 0 deletions
diff --git a/py/objlist.c b/py/objlist.c
index c153d2222b..fa8ec67d09 100644
--- a/py/objlist.c
+++ b/py/objlist.c
@@ -62,6 +62,61 @@ static mp_obj_t list_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args
return NULL;
}
+// Don't pass RT_COMPARE_OP_NOT_EQUAL here
+static bool list_cmp_helper(int op, mp_obj_t self_in, mp_obj_t another_in) {
+ assert(MP_OBJ_IS_TYPE(self_in, &list_type));
+ if (!MP_OBJ_IS_TYPE(another_in, &list_type)) {
+ return false;
+ }
+ mp_obj_list_t *self = self_in;
+ mp_obj_list_t *another = another_in;
+ if (op == RT_COMPARE_OP_EQUAL && self->len != another->len) {
+ return false;
+ }
+
+ // Let's deal only with > & >=
+ if (op == RT_COMPARE_OP_LESS || op == RT_COMPARE_OP_LESS_EQUAL) {
+ mp_obj_t t = self;
+ self = another;
+ another = t;
+ if (op == RT_COMPARE_OP_LESS) {
+ op = RT_COMPARE_OP_MORE;
+ } else {
+ op = RT_COMPARE_OP_MORE_EQUAL;
+ }
+ }
+
+ int len = self->len < another->len ? self->len : another->len;
+ bool eq_status = true; // empty lists are equal
+ bool rel_status;
+ for (int i = 0; i < len; i++) {
+ eq_status = mp_obj_equal(self->items[i], another->items[i]);
+ if (op == RT_COMPARE_OP_EQUAL && !eq_status) {
+ return false;
+ }
+ rel_status = (rt_binary_op(op, self->items[i], another->items[i]) == mp_const_true);
+ if (!eq_status && !rel_status) {
+ return false;
+ }
+ }
+
+ // If we had tie in the last element...
+ if (eq_status) {
+ // ... and we have lists of different lengths...
+ if (self->len != another->len) {
+ if (self->len < another->len) {
+ // ... then longer list length wins (we deal only with >)
+ return false;
+ }
+ } else if (op == RT_COMPARE_OP_MORE) {
+ // Otherwise, if we have strict relation, equality means failure
+ return false;
+ }
+ }
+
+ return true;
+}
+
static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
mp_obj_list_t *o = lhs;
switch (op) {
@@ -105,6 +160,15 @@ static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
}
return s;
}
+ case RT_COMPARE_OP_EQUAL:
+ case RT_COMPARE_OP_LESS:
+ case RT_COMPARE_OP_LESS_EQUAL:
+ case RT_COMPARE_OP_MORE:
+ case RT_COMPARE_OP_MORE_EQUAL:
+ return MP_BOOL(list_cmp_helper(op, lhs, rhs));
+ case RT_COMPARE_OP_NOT_EQUAL:
+ return MP_BOOL(!list_cmp_helper(RT_COMPARE_OP_EQUAL, lhs, rhs));
+
default:
// op not supported
return NULL;
diff --git a/tests/basics/tests/list_compare.py b/tests/basics/tests/list_compare.py
new file mode 100644
index 0000000000..eea8814247
--- /dev/null
+++ b/tests/basics/tests/list_compare.py
@@ -0,0 +1,50 @@
+print([] == [])
+print([] > [])
+print([] < [])
+print([] == [1])
+print([1] == [])
+print([] > [1])
+print([1] > [])
+print([] < [1])
+print([1] < [])
+print([] >= [1])
+print([1] >= [])
+print([] <= [1])
+print([1] <= [])
+
+print([1] == [1])
+print([1] != [1])
+print([1] == [2])
+print([1] == [1, 0])
+
+print([1] > [1])
+print([1] > [2])
+print([2] > [1])
+print([1, 0] > [1])
+print([1, -1] > [1])
+print([1] > [1, 0])
+print([1] > [1, -1])
+
+print([1] < [1])
+print([2] < [1])
+print([1] < [2])
+print([1] < [1, 0])
+print([1] < [1, -1])
+print([1, 0] < [1])
+print([1, -1] < [1])
+
+print([1] >= [1])
+print([1] >= [2])
+print([2] >= [1])
+print([1, 0] >= [1])
+print([1, -1] >= [1])
+print([1] >= [1, 0])
+print([1] >= [1, -1])
+
+print([1] <= [1])
+print([2] <= [1])
+print([1] <= [2])
+print([1] <= [1, 0])
+print([1] <= [1, -1])
+print([1, 0] <= [1])
+print([1, -1] <= [1])