diff options
Diffstat (limited to 'py/objlist.c')
-rw-r--r-- | py/objlist.c | 68 |
1 files changed, 43 insertions, 25 deletions
diff --git a/py/objlist.c b/py/objlist.c index 02a6b1525b..5162fa09ff 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -8,6 +8,7 @@ #include "mpconfig.h" #include "mpqstr.h" #include "obj.h" +#include "map.h" #include "runtime0.h" #include "runtime.h" @@ -57,6 +58,7 @@ static mp_obj_t list_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args default: nlr_jump(mp_obj_new_exception_msg_1_arg(MP_QSTR_TypeError, "list takes at most 1 argument, %d given", (void*)(machine_int_t)n_args)); } + return NULL; } static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { @@ -119,14 +121,15 @@ static mp_obj_t list_pop(int n_args, const mp_obj_t *args) { } // TODO make this conform to CPython's definition of sort -static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn) { +static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool reversed) { + int op = reversed ? RT_COMPARE_OP_MORE : RT_COMPARE_OP_LESS; while (head < tail) { mp_obj_t *h = head - 1; mp_obj_t *t = tail; - mp_obj_t v = rt_call_function_1(key_fn, tail[0]); // get pivot using key_fn + mp_obj_t v = key_fn == NULL ? tail[0] : rt_call_function_1(key_fn, tail[0]); // get pivot using key_fn for (;;) { - do ++h; while (rt_compare_op(RT_COMPARE_OP_LESS, rt_call_function_1(key_fn, h[0]), v) == mp_const_true); - do --t; while (h < t && rt_compare_op(RT_COMPARE_OP_LESS, v, rt_call_function_1(key_fn, t[0])) == mp_const_true); + do ++h; while (rt_compare_op(op, key_fn == NULL ? h[0] : rt_call_function_1(key_fn, h[0]), v) == mp_const_true); + do --t; while (h < t && rt_compare_op(op, v, key_fn == NULL ? t[0] : rt_call_function_1(key_fn, t[0])) == mp_const_true); if (h >= t) break; mp_obj_t x = h[0]; h[0] = t[0]; @@ -135,16 +138,31 @@ static void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn) { mp_obj_t x = h[0]; h[0] = tail[0]; tail[0] = x; - mp_quicksort(head, t, key_fn); + mp_quicksort(head, t, key_fn, reversed); head = h + 1; } } -static mp_obj_t list_sort(mp_obj_t self_in, mp_obj_t key_fn) { - assert(MP_OBJ_IS_TYPE(self_in, &list_type)); - mp_obj_list_t *self = self_in; +static mp_obj_t list_sort(mp_obj_t *args, mp_map_t *kwargs) { + mp_obj_t *args_items = NULL; + machine_uint_t args_len = 0; + qstr key_idx = qstr_from_str_static("key"); + qstr reverse_idx = qstr_from_str_static("reverse"); + + assert(MP_OBJ_IS_TYPE(args, &tuple_type)); + mp_obj_tuple_get(args, &args_len, &args_items); + assert(args_len >= 1); + if (args_len > 1) { + nlr_jump(mp_obj_new_exception_msg(MP_QSTR_TypeError, + "list.sort takes no positional arguments")); + } + mp_obj_list_t *self = args_items[0]; if (self->len > 1) { - mp_quicksort(self->items, self->items + self->len - 1, key_fn); + mp_map_elem_t *keyfun = mp_qstr_map_lookup(kwargs, key_idx, false); + mp_map_elem_t *reverse = mp_qstr_map_lookup(kwargs, reverse_idx, false); + mp_quicksort(self->items, self->items + self->len - 1, + keyfun ? keyfun->value : NULL, + reverse && reverse->value ? rt_is_true(reverse->value) : false); } return mp_const_none; // return None, as per CPython } @@ -258,29 +276,30 @@ static MP_DEFINE_CONST_FUN_OBJ_3(list_insert_obj, list_insert); static MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(list_pop_obj, 1, 2, list_pop); static MP_DEFINE_CONST_FUN_OBJ_2(list_remove_obj, list_remove); static MP_DEFINE_CONST_FUN_OBJ_1(list_reverse_obj, list_reverse); -static MP_DEFINE_CONST_FUN_OBJ_2(list_sort_obj, list_sort); +static MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, list_sort); + +static const mp_method_t list_type_methods[] = { + { "append", &list_append_obj }, + { "clear", &list_clear_obj }, + { "copy", &list_copy_obj }, + { "count", &list_count_obj }, + { "index", &list_index_obj }, + { "insert", &list_insert_obj }, + { "pop", &list_pop_obj }, + { "remove", &list_remove_obj }, + { "reverse", &list_reverse_obj }, + { "sort", &list_sort_obj }, + { NULL, NULL }, // end-of-list sentinel +}; const mp_obj_type_t list_type = { { &mp_const_type }, "list", .print = list_print, .make_new = list_make_new, - .unary_op = NULL, .binary_op = list_binary_op, .getiter = list_getiter, - .methods = { - { "append", &list_append_obj }, - { "clear", &list_clear_obj }, - { "copy", &list_copy_obj }, - { "count", &list_count_obj }, - { "index", &list_index_obj }, - { "insert", &list_insert_obj }, - { "pop", &list_pop_obj }, - { "remove", &list_remove_obj }, - { "reverse", &list_reverse_obj }, - { "sort", &list_sort_obj }, - { NULL, NULL }, // end-of-list sentinel - }, + .methods = list_type_methods, }; static mp_obj_list_t *list_new(uint n) { @@ -344,7 +363,6 @@ static const mp_obj_type_t list_it_type = { { &mp_const_type }, "list_iterator", .iternext = list_it_iternext, - .methods = { { NULL, NULL }, }, }; mp_obj_t mp_obj_new_list_iterator(mp_obj_list_t *list, int cur) { |