summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--py/modstruct.c55
-rw-r--r--tests/basics/struct1.py6
-rw-r--r--tests/basics/struct2.py6
3 files changed, 38 insertions, 29 deletions
diff --git a/py/modstruct.c b/py/modstruct.c
index 0d4a45f6b6..61dd0f81b4 100644
--- a/py/modstruct.c
+++ b/py/modstruct.c
@@ -82,26 +82,10 @@ STATIC mp_uint_t get_fmt_num(const char **p) {
return val;
}
-STATIC uint calcsize_items(const char *fmt) {
- uint cnt = 0;
- while (*fmt) {
- int num = 1;
- if (unichar_isdigit(*fmt)) {
- num = get_fmt_num(&fmt);
- if (*fmt == 's') {
- num = 1;
- }
- }
- cnt += num;
- fmt++;
- }
- return cnt;
-}
-
-STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) {
- const char *fmt = mp_obj_str_get_str(fmt_in);
+STATIC size_t calc_size_items(const char *fmt, size_t *total_sz) {
char fmt_type = get_fmt_type(&fmt);
- mp_uint_t size;
+ size_t total_cnt = 0;
+ size_t size;
for (size = 0; *fmt; fmt++) {
mp_uint_t cnt = 1;
if (unichar_isdigit(*fmt)) {
@@ -109,8 +93,10 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) {
}
if (*fmt == 's') {
+ total_cnt += 1;
size += cnt;
} else {
+ total_cnt += cnt;
mp_uint_t align;
size_t sz = mp_binary_get_size(fmt_type, *fmt, &align);
while (cnt--) {
@@ -120,6 +106,14 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) {
}
}
}
+ *total_sz = size;
+ return total_cnt;
+}
+
+STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) {
+ const char *fmt = mp_obj_str_get_str(fmt_in);
+ size_t size;
+ calc_size_items(fmt, &size);
return MP_OBJ_NEW_SMALL_INT(size);
}
MP_DEFINE_CONST_FUN_OBJ_1(struct_calcsize_obj, struct_calcsize);
@@ -130,8 +124,9 @@ STATIC mp_obj_t struct_unpack_from(size_t n_args, const mp_obj_t *args) {
// Since we implement unpack and unpack_from using the same function
// we relax the "exact" requirement, and only implement "big enough".
const char *fmt = mp_obj_str_get_str(args[0]);
+ size_t total_sz;
+ size_t num_items = calc_size_items(fmt, &total_sz);
char fmt_type = get_fmt_type(&fmt);
- uint num_items = calcsize_items(fmt);
mp_obj_tuple_t *res = MP_OBJ_TO_PTR(mp_obj_new_tuple(num_items, NULL));
mp_buffer_info_t bufinfo;
mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ);
@@ -152,21 +147,23 @@ STATIC mp_obj_t struct_unpack_from(size_t n_args, const mp_obj_t *args) {
p += offset;
}
- for (uint i = 0; i < num_items;) {
- mp_uint_t sz = 1;
+ // Check that the input buffer is big enough to unpack all the values
+ if (p + total_sz > end_p) {
+ mp_raise_ValueError("buffer too small");
+ }
+
+ for (size_t i = 0; i < num_items;) {
+ mp_uint_t cnt = 1;
if (unichar_isdigit(*fmt)) {
- sz = get_fmt_num(&fmt);
- }
- if (p + sz > end_p) {
- mp_raise_ValueError("buffer too small");
+ cnt = get_fmt_num(&fmt);
}
mp_obj_t item;
if (*fmt == 's') {
- item = mp_obj_new_bytes(p, sz);
- p += sz;
+ item = mp_obj_new_bytes(p, cnt);
+ p += cnt;
res->items[i++] = item;
} else {
- while (sz--) {
+ while (cnt--) {
item = mp_binary_get_val(fmt_type, *fmt, &p);
res->items[i++] = item;
}
diff --git a/tests/basics/struct1.py b/tests/basics/struct1.py
index a442beb1e5..2cf75137b8 100644
--- a/tests/basics/struct1.py
+++ b/tests/basics/struct1.py
@@ -39,6 +39,12 @@ print(v == (10, 100, 200, 300))
# network byte order
print(struct.pack('!i', 123))
+# check that we get an error if the buffer is too small
+try:
+ struct.unpack('I', b'\x00\x00\x00')
+except:
+ print('struct.error')
+
# first arg must be a string
try:
struct.pack(1, 2)
diff --git a/tests/basics/struct2.py b/tests/basics/struct2.py
index 3b9dd5c1f6..e3336c0c78 100644
--- a/tests/basics/struct2.py
+++ b/tests/basics/struct2.py
@@ -25,6 +25,12 @@ print(struct.calcsize('0s1s0H2H'))
print(struct.unpack('<0s1s0H2H', b'01234'))
print(struct.pack('<0s1s0H2H', b'abc', b'abc', 258, 515))
+# check that we get an error if the buffer is too small
+try:
+ struct.unpack('2H', b'\x00\x00')
+except:
+ print('Exception')
+
# check that unknown types raise an exception
try:
struct.unpack('z', b'1')