summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/objint.c2
-rw-r--r--py/objlist.c3
-rw-r--r--py/objstr.c10
-rw-r--r--py/objtuple.c7
-rw-r--r--tests/basics/bytes_mult.py12
-rw-r--r--tests/basics/list_mult.py10
-rw-r--r--tests/basics/string_mult.py12
-rw-r--r--tests/basics/tuple_mult.py10
8 files changed, 59 insertions, 7 deletions
diff --git a/py/objint.c b/py/objint.c
index c08bf7da6c..d088ae1a80 100644
--- a/py/objint.c
+++ b/py/objint.c
@@ -289,7 +289,7 @@ mp_obj_t mp_obj_int_binary_op_extra_cases(int op, mp_obj_t lhs_in, mp_obj_t rhs_
// true acts as 0
return mp_binary_op(op, lhs_in, MP_OBJ_NEW_SMALL_INT(1));
} else if (op == MP_BINARY_OP_MULTIPLY) {
- if (MP_OBJ_IS_STR(rhs_in) || MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple) || MP_OBJ_IS_TYPE(rhs_in, &mp_type_list)) {
+ if (MP_OBJ_IS_STR(rhs_in) || MP_OBJ_IS_TYPE(rhs_in, &mp_type_bytes) || MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple) || MP_OBJ_IS_TYPE(rhs_in, &mp_type_list)) {
// multiply is commutative for these types, so delegate to them
return mp_binary_op(op, rhs_in, lhs_in);
}
diff --git a/py/objlist.c b/py/objlist.c
index 655a78908e..578e39452a 100644
--- a/py/objlist.c
+++ b/py/objlist.c
@@ -131,6 +131,9 @@ STATIC mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
if (!mp_obj_get_int_maybe(rhs, &n)) {
return MP_OBJ_NULL; // op not supported
}
+ if (n < 0) {
+ n = 0;
+ }
mp_obj_list_t *s = list_new(o->len * n);
mp_seq_multiply(o->items, sizeof(*o->items), o->len, n, s->items);
return s;
diff --git a/py/objstr.c b/py/objstr.c
index 35bb8e749c..e884794591 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -290,10 +290,16 @@ mp_obj_t mp_obj_str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
break;
case MP_BINARY_OP_MULTIPLY: {
- if (!MP_OBJ_IS_SMALL_INT(rhs_in)) {
+ mp_int_t n;
+ if (!mp_obj_get_int_maybe(rhs_in, &n)) {
return MP_OBJ_NULL; // op not supported
}
- int n = MP_OBJ_SMALL_INT_VALUE(rhs_in);
+ if (n <= 0) {
+ if (lhs_type == &mp_type_str) {
+ return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
+ }
+ n = 0;
+ }
byte *data;
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
diff --git a/py/objtuple.c b/py/objtuple.c
index 3dade2f748..377fbf5431 100644
--- a/py/objtuple.c
+++ b/py/objtuple.c
@@ -137,10 +137,13 @@ mp_obj_t mp_obj_tuple_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
return s;
}
case MP_BINARY_OP_MULTIPLY: {
- if (!MP_OBJ_IS_SMALL_INT(rhs)) {
+ mp_int_t n;
+ if (!mp_obj_get_int_maybe(rhs, &n)) {
return MP_OBJ_NULL; // op not supported
}
- int n = MP_OBJ_SMALL_INT_VALUE(rhs);
+ if (n <= 0) {
+ return mp_const_empty_tuple;
+ }
mp_obj_tuple_t *s = mp_obj_new_tuple(o->len * n, NULL);
mp_seq_multiply(o->items, sizeof(*o->items), o->len, n, s->items);
return s;
diff --git a/tests/basics/bytes_mult.py b/tests/basics/bytes_mult.py
new file mode 100644
index 0000000000..0effd938ea
--- /dev/null
+++ b/tests/basics/bytes_mult.py
@@ -0,0 +1,12 @@
+# basic multiplication
+print(b'0' * 5)
+
+# check negative, 0, positive; lhs and rhs multiplication
+for i in (-4, -2, 0, 2, 4):
+ print(i * b'12')
+ print(b'12' * i)
+
+# check that we don't modify existing object
+a = b'123'
+c = a * 3
+print(a, c)
diff --git a/tests/basics/list_mult.py b/tests/basics/list_mult.py
index ec65fbb3f4..16948f74c2 100644
--- a/tests/basics/list_mult.py
+++ b/tests/basics/list_mult.py
@@ -1,4 +1,12 @@
+# basic multiplication
print([0] * 5)
+
+# check negative, 0, positive; lhs and rhs multiplication
+for i in (-4, -2, 0, 2, 4):
+ print(i * [1, 2])
+ print([1, 2] * i)
+
+# check that we don't modify existing list
a = [1, 2, 3]
c = a * 3
-print(c)
+print(a, c)
diff --git a/tests/basics/string_mult.py b/tests/basics/string_mult.py
new file mode 100644
index 0000000000..c0713c1d3a
--- /dev/null
+++ b/tests/basics/string_mult.py
@@ -0,0 +1,12 @@
+# basic multiplication
+print('0' * 5)
+
+# check negative, 0, positive; lhs and rhs multiplication
+for i in (-4, -2, 0, 2, 4):
+ print(i * '12')
+ print('12' * i)
+
+# check that we don't modify existing object
+a = '123'
+c = a * 3
+print(a, c)
diff --git a/tests/basics/tuple_mult.py b/tests/basics/tuple_mult.py
index f8350f2f27..0f52bce44e 100644
--- a/tests/basics/tuple_mult.py
+++ b/tests/basics/tuple_mult.py
@@ -1,4 +1,12 @@
+# basic multiplication
print((0,) * 5)
+
+# check negative, 0, positive; lhs and rhs multiplication
+for i in (-4, -2, 0, 2, 4):
+ print(i * (1, 2))
+ print((1, 2) * i)
+
+# check that we don't modify existing tuple
a = (1, 2, 3)
c = a * 3
-print(c)
+print(a, c)