diff options
Diffstat (limited to 'py/objlist.c')
-rw-r--r-- | py/objlist.c | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/py/objlist.c b/py/objlist.c index df9e1974f9..5162fa09ff 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -8,6 +8,7 @@ #include "mpconfig.h" #include "mpqstr.h" #include "obj.h" +#include "map.h" #include "runtime0.h" #include "runtime.h" @@ -120,14 +121,15 @@ static mp_obj_t list_pop(int n_args, const mp_obj_t *args) { } // TODO make this conform to CPython's definition of sort -static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn) { +static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool reversed) { + int op = reversed ? RT_COMPARE_OP_MORE : RT_COMPARE_OP_LESS; while (head < tail) { mp_obj_t *h = head - 1; mp_obj_t *t = tail; - mp_obj_t v = rt_call_function_1(key_fn, tail[0]); // get pivot using key_fn + mp_obj_t v = key_fn == NULL ? tail[0] : rt_call_function_1(key_fn, tail[0]); // get pivot using key_fn for (;;) { - do ++h; while (rt_compare_op(RT_COMPARE_OP_LESS, rt_call_function_1(key_fn, h[0]), v) == mp_const_true); - do --t; while (h < t && rt_compare_op(RT_COMPARE_OP_LESS, v, rt_call_function_1(key_fn, t[0])) == mp_const_true); + do ++h; while (rt_compare_op(op, key_fn == NULL ? h[0] : rt_call_function_1(key_fn, h[0]), v) == mp_const_true); + do --t; while (h < t && rt_compare_op(op, v, key_fn == NULL ? t[0] : rt_call_function_1(key_fn, t[0])) == mp_const_true); if (h >= t) break; mp_obj_t x = h[0]; h[0] = t[0]; @@ -136,16 +138,31 @@ static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn) { mp_obj_t x = h[0]; h[0] = tail[0]; tail[0] = x; - mp_quicksort(head, t, key_fn); + mp_quicksort(head, t, key_fn, reversed); head = h + 1; } } -static mp_obj_t list_sort(mp_obj_t self_in, mp_obj_t key_fn) { - assert(MP_OBJ_IS_TYPE(self_in, &list_type)); - mp_obj_list_t *self = self_in; +static mp_obj_t list_sort(mp_obj_t *args, mp_map_t *kwargs) { + mp_obj_t *args_items = NULL; + machine_uint_t args_len = 0; + qstr key_idx = qstr_from_str_static("key"); + qstr reverse_idx = qstr_from_str_static("reverse"); + + assert(MP_OBJ_IS_TYPE(args, &tuple_type)); + mp_obj_tuple_get(args, &args_len, &args_items); + assert(args_len >= 1); + if (args_len > 1) { + nlr_jump(mp_obj_new_exception_msg(MP_QSTR_TypeError, + "list.sort takes no positional arguments")); + } + mp_obj_list_t *self = args_items[0]; if (self->len > 1) { - mp_quicksort(self->items, self->items + self->len - 1, key_fn); + mp_map_elem_t *keyfun = mp_qstr_map_lookup(kwargs, key_idx, false); + mp_map_elem_t *reverse = mp_qstr_map_lookup(kwargs, reverse_idx, false); + mp_quicksort(self->items, self->items + self->len - 1, + keyfun ? keyfun->value : NULL, + reverse && reverse->value ? rt_is_true(reverse->value) : false); } return mp_const_none; // return None, as per CPython } @@ -259,7 +276,7 @@ static MP_DEFINE_CONST_FUN_OBJ_3(list_insert_obj, list_insert); static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(list_pop_obj, 1, 2, list_pop); static MP_DEFINE_CONST_FUN_OBJ_2(list_remove_obj, list_remove); static MP_DEFINE_CONST_FUN_OBJ_1(list_reverse_obj, list_reverse); -static MP_DEFINE_CONST_FUN_OBJ_2(list_sort_obj, list_sort); +static MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, list_sort); static const mp_method_t list_type_methods[] = { { "append", &list_append_obj }, |