summaryrefslogtreecommitdiffstatshomepage
path: root/py
diff options
context:
space:
mode:
Diffstat (limited to 'py')
-rw-r--r--py/compile.c12
-rw-r--r--py/obj.h1
-rw-r--r--py/objlist.c7
-rw-r--r--py/runtime.c64
-rw-r--r--py/runtime.h1
-rw-r--r--py/vm.c6
6 files changed, 91 insertions, 0 deletions
diff --git a/py/compile.c b/py/compile.c
index 1fc5f07227..95fe1d759f 100644
--- a/py/compile.c
+++ b/py/compile.c
@@ -1933,6 +1933,11 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
// optimisation for a, b = c, d; to match CPython's optimisation
mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0];
mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0];
+ if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr)
+ || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr)) {
+ // can't optimise when it's a star expression on the lhs
+ goto no_optimisation;
+ }
compile_node(comp, pns10->nodes[0]); // rhs
compile_node(comp, pns10->nodes[1]); // rhs
EMIT(rot_two);
@@ -1945,6 +1950,12 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
// optimisation for a, b, c = d, e, f; to match CPython's optimisation
mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0];
mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0];
+ if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr)
+ || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr)
+ || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[2], PN_star_expr)) {
+ // can't optimise when it's a star expression on the lhs
+ goto no_optimisation;
+ }
compile_node(comp, pns10->nodes[0]); // rhs
compile_node(comp, pns10->nodes[1]); // rhs
compile_node(comp, pns10->nodes[2]); // rhs
@@ -1954,6 +1965,7 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
c_assign(comp, pns0->nodes[1], ASSIGN_STORE); // lhs store
c_assign(comp, pns0->nodes[2], ASSIGN_STORE); // lhs store
} else {
+ no_optimisation:
compile_node(comp, pns1->nodes[0]); // rhs
c_assign(comp, pns->nodes[0], ASSIGN_STORE); // lhs store
}
diff --git a/py/obj.h b/py/obj.h
index fc99055b6e..77cf7838ee 100644
--- a/py/obj.h
+++ b/py/obj.h
@@ -440,6 +440,7 @@ mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m
// list
mp_obj_t mp_obj_list_append(mp_obj_t self_in, mp_obj_t arg);
void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items);
+void mp_obj_list_set_len(mp_obj_t self_in, uint len);
void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value);
mp_obj_t mp_obj_list_sort(uint n_args, const mp_obj_t *args, mp_map_t *kwargs);
diff --git a/py/objlist.c b/py/objlist.c
index 620bf2944a..371d1cb26e 100644
--- a/py/objlist.c
+++ b/py/objlist.c
@@ -378,6 +378,13 @@ void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items) {
*items = self->items;
}
+void mp_obj_list_set_len(mp_obj_t self_in, uint len) {
+ // trust that the caller knows what it's doing
+ // TODO realloc if len got much smaller than alloc
+ mp_obj_list_t *self = self_in;
+ self->len = len;
+}
+
void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
mp_obj_list_t *self = self_in;
uint i = mp_get_index(self->base.type, self->len, index, false);
diff --git a/py/runtime.c b/py/runtime.c
index 44e0ded507..3d1ae72c2f 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -672,6 +672,70 @@ too_long:
nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "too many values to unpack (expected %d)", num));
}
+// unpacked items are stored in reverse order into the array pointed to by items
+void mp_unpack_ex(mp_obj_t seq_in, uint num_in, mp_obj_t *items) {
+ uint num_left = num_in & 0xff;
+ uint num_right = (num_in >> 8) & 0xff;
+ DEBUG_OP_printf("unpack ex %d %d\n", num_left, num_right);
+ uint seq_len;
+ if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple) || MP_OBJ_IS_TYPE(seq_in, &mp_type_list)) {
+ mp_obj_t *seq_items;
+ if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple)) {
+ mp_obj_tuple_get(seq_in, &seq_len, &seq_items);
+ } else {
+ if (num_left == 0 && num_right == 0) {
+ // *a, = b # sets a to b if b is a list
+ items[0] = seq_in;
+ return;
+ }
+ mp_obj_list_get(seq_in, &seq_len, &seq_items);
+ }
+ if (seq_len < num_left + num_right) {
+ goto too_short;
+ }
+ for (uint i = 0; i < num_right; i++) {
+ items[i] = seq_items[seq_len - 1 - i];
+ }
+ items[num_right] = mp_obj_new_list(seq_len - num_left - num_right, seq_items + num_left);
+ for (uint i = 0; i < num_left; i++) {
+ items[num_right + 1 + i] = seq_items[num_left - 1 - i];
+ }
+ } else {
+ // Generic iterable; this gets a bit messy: we unpack known left length to the
+ // items destination array, then the rest to a dynamically created list. Once the
+ // iterable is exhausted, we take from this list for the right part of the items.
+ // TODO Improve to waste less memory in the dynamically created list.
+ mp_obj_t iterable = mp_getiter(seq_in);
+ mp_obj_t item;
+ for (seq_len = 0; seq_len < num_left; seq_len++) {
+ item = mp_iternext(iterable);
+ if (item == MP_OBJ_NULL) {
+ goto too_short;
+ }
+ items[num_left + num_right + 1 - 1 - seq_len] = item;
+ }
+ mp_obj_t rest = mp_obj_new_list(0, NULL);
+ while ((item = mp_iternext(iterable)) != MP_OBJ_NULL) {
+ mp_obj_list_append(rest, item);
+ }
+ uint rest_len;
+ mp_obj_t *rest_items;
+ mp_obj_list_get(rest, &rest_len, &rest_items);
+ if (rest_len < num_right) {
+ goto too_short;
+ }
+ items[num_right] = rest;
+ for (uint i = 0; i < num_right; i++) {
+ items[num_right - 1 - i] = rest_items[rest_len - num_right + i];
+ }
+ mp_obj_list_set_len(rest, rest_len - num_right);
+ }
+ return;
+
+too_short:
+ nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "need more than %d values to unpack", seq_len));
+}
+
mp_obj_t mp_load_attr(mp_obj_t base, qstr attr) {
DEBUG_OP_printf("load attr %p.%s\n", base, qstr_str(attr));
// use load_method
diff --git a/py/runtime.h b/py/runtime.h
index cc76186f4e..ab34be2da9 100644
--- a/py/runtime.h
+++ b/py/runtime.h
@@ -44,6 +44,7 @@ mp_obj_t mp_call_method_n_kw(uint n_args, uint n_kw, const mp_obj_t *args);
mp_obj_t mp_call_method_n_kw_var(bool have_self, uint n_args_n_kw, const mp_obj_t *args);
void mp_unpack_sequence(mp_obj_t seq, uint num, mp_obj_t *items);
+void mp_unpack_ex(mp_obj_t seq, uint num, mp_obj_t *items);
mp_obj_t mp_store_map(mp_obj_t map, mp_obj_t key, mp_obj_t value);
mp_obj_t mp_load_attr(mp_obj_t base, qstr attr);
void mp_load_method(mp_obj_t base, qstr attr, mp_obj_t *dest);
diff --git a/py/vm.c b/py/vm.c
index 2e64cd9573..869a9381ad 100644
--- a/py/vm.c
+++ b/py/vm.c
@@ -653,6 +653,12 @@ unwind_jump:
sp += unum - 1;
break;
+ case MP_BC_UNPACK_EX:
+ DECODE_UINT;
+ mp_unpack_ex(sp[0], unum, sp);
+ sp += (unum & 0xff) + ((unum >> 8) & 0xff);
+ break;
+
case MP_BC_MAKE_FUNCTION:
DECODE_UINT;
PUSH(mp_make_function_from_id(unum, MP_OBJ_NULL, MP_OBJ_NULL));