diff options
-rw-r--r-- | py/mpqstrraw.h | 1 | ||||
-rw-r--r-- | py/obj.h | 3 | ||||
-rw-r--r-- | py/objfilter.c | 53 | ||||
-rw-r--r-- | py/py.mk | 1 | ||||
-rw-r--r-- | py/runtime.c | 1 | ||||
-rw-r--r-- | tests/basics/tests/filter.py | 2 |
6 files changed, 61 insertions, 0 deletions
diff --git a/py/mpqstrraw.h b/py/mpqstrraw.h index bbca1bb43e..f6b4444a70 100644 --- a/py/mpqstrraw.h +++ b/py/mpqstrraw.h @@ -42,6 +42,7 @@ Q(complex) Q(dict) Q(divmod) Q(enumerate) +Q(filter) Q(float) Q(hash) Q(int) @@ -300,6 +300,9 @@ extern const mp_obj_type_t map_type; // enumerate extern const mp_obj_type_t enumerate_type; +// filter +extern const mp_obj_type_t filter_type; + // dict extern const mp_obj_type_t dict_type; uint mp_obj_dict_len(mp_obj_t self_in); diff --git a/py/objfilter.c b/py/objfilter.c new file mode 100644 index 0000000000..6696ffe32c --- /dev/null +++ b/py/objfilter.c @@ -0,0 +1,53 @@ +#include <stdlib.h> +#include <assert.h> + +#include "misc.h" +#include "mpconfig.h" +#include "obj.h" +#include "runtime.h" + +typedef struct _mp_obj_filter_t { + mp_obj_base_t base; + mp_obj_t fun; + mp_obj_t iter; +} mp_obj_filter_t; + +static mp_obj_t filter_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) { + /* NOTE: args are backwards */ + mp_obj_filter_t *o = m_new_obj(mp_obj_filter_t); + assert(n_args == 2); + o->base.type = &filter_type; + o->fun = args[1]; + o->iter = rt_getiter(args[0]); + return o; +} + +static mp_obj_t filter_getiter(mp_obj_t self_in) { + return self_in; +} + +static mp_obj_t filter_iternext(mp_obj_t self_in) { + assert(MP_OBJ_IS_TYPE(self_in, &filter_type)); + mp_obj_filter_t *self = self_in; + mp_obj_t next; + while ((next = rt_iternext(self->iter)) != mp_const_stop_iteration) { + mp_obj_t val; + if (self->fun != mp_const_none) { + val = rt_call_function_n(self->fun, 1, &next); + } else { + val = next; + } + if (rt_is_true(val)) { + return next; + } + } + return mp_const_stop_iteration; +} + +const mp_obj_type_t filter_type = { + { &mp_const_type }, + "filter", + .make_new = filter_make_new, + .getiter = filter_getiter, + .iternext = filter_iternext, +}; @@ -79,6 +79,7 @@ PY_O_BASENAME = \ objdict.o \ objenumerate.o \ objexcept.o \ + objfilter.o \ objfloat.o \ objfun.o \ objgenerator.o \ diff --git a/py/runtime.c b/py/runtime.c index c5292dc263..04c098de3c 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -106,6 +106,7 @@ void rt_init(void) { #endif mp_map_add_qstr(&map_builtins, MP_QSTR_dict, (mp_obj_t)&dict_type); mp_map_add_qstr(&map_builtins, MP_QSTR_enumerate, (mp_obj_t)&enumerate_type); + mp_map_add_qstr(&map_builtins, MP_QSTR_filter, (mp_obj_t)&filter_type); #if MICROPY_ENABLE_FLOAT mp_map_add_qstr(&map_builtins, MP_QSTR_float, (mp_obj_t)&float_type); #endif diff --git a/tests/basics/tests/filter.py b/tests/basics/tests/filter.py new file mode 100644 index 0000000000..5883e3d00b --- /dev/null +++ b/tests/basics/tests/filter.py @@ -0,0 +1,2 @@ +print(list(filter(lambda x: x & 1, range(-3, 4)))) +print(list(filter(None, range(-3, 4)))) |