diff options
-rw-r--r-- | Modules/_zstd/_zstdmodule.c | 123 |
1 files changed, 55 insertions, 68 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 0294828aa10..b2e4f95b906 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -172,6 +172,49 @@ get_zstd_state(PyObject *module) return (_zstd_state *)state; } +static Py_ssize_t +calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, + size_t **chunk_sizes) +{ + Py_ssize_t chunks_number; + Py_ssize_t sizes_sum; + Py_ssize_t i; + + chunks_number = Py_SIZE(samples_sizes); + if ((size_t) chunks_number > UINT32_MAX) { + PyErr_Format(PyExc_ValueError, + "The number of samples should be <= %u.", UINT32_MAX); + return -1; + } + + /* Prepare chunk_sizes */ + *chunk_sizes = PyMem_New(size_t, chunks_number); + if (*chunk_sizes == NULL) { + PyErr_NoMemory(); + return -1; + } + + sizes_sum = 0; + for (i = 0; i < chunks_number; i++) { + PyObject *size = PyTuple_GetItem(samples_sizes, i); + (*chunk_sizes)[i] = PyLong_AsSize_t(size); + if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { + PyErr_Format(PyExc_ValueError, + "Items in samples_sizes should be an int " + "object, with a value between 0 and %u.", SIZE_MAX); + return -1; + } + sizes_sum += (*chunk_sizes)[i]; + } + + if (sizes_sum != Py_SIZE(samples_bytes)) { + PyErr_SetString(PyExc_ValueError, + "The samples size tuple doesn't match the concatenation's size."); + return -1; + } + return chunks_number; +} + /*[clinic input] _zstd.train_dict @@ -192,14 +235,10 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t dict_size) /*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/ { - // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict - // are pretty similar. We should see if we can refactor them to share that code. - Py_ssize_t chunks_number; - size_t *chunk_sizes = NULL; PyObject *dst_dict_bytes = NULL; + size_t *chunk_sizes = NULL; + Py_ssize_t chunks_number; size_t zstd_ret; - Py_ssize_t sizes_sum; - Py_ssize_t i; /* Check arguments */ if (dict_size <= 0) { @@ -207,39 +246,14 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, return NULL; } - chunks_number = Py_SIZE(samples_sizes); - if ((size_t) chunks_number > UINT32_MAX) { - PyErr_Format(PyExc_ValueError, - "The number of samples should be <= %u.", UINT32_MAX); + /* Check that the samples are valid and get their sizes */ + chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, + &chunk_sizes); + if (chunks_number < 0) + { return NULL; } - /* Prepare chunk_sizes */ - chunk_sizes = PyMem_New(size_t, chunks_number); - if (chunk_sizes == NULL) { - PyErr_NoMemory(); - goto error; - } - - sizes_sum = 0; - for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - chunk_sizes[i] = PyLong_AsSize_t(size); - if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); - goto error; - } - sizes_sum += chunk_sizes[i]; - } - - if (sizes_sum != Py_SIZE(samples_bytes)) { - PyErr_SetString(PyExc_ValueError, - "The samples size tuple doesn't match the concatenation's size."); - goto error; - } - /* Allocate dict buffer */ dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); if (dst_dict_bytes == NULL) { @@ -307,8 +321,6 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, PyObject *dst_dict_bytes = NULL; size_t zstd_ret; ZDICT_params_t params; - Py_ssize_t sizes_sum; - Py_ssize_t i; /* Check arguments */ if (dict_size <= 0) { @@ -316,39 +328,14 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, return NULL; } - chunks_number = Py_SIZE(samples_sizes); - if ((size_t) chunks_number > UINT32_MAX) { - PyErr_Format(PyExc_ValueError, - "The number of samples should be <= %u.", UINT32_MAX); + /* Check that the samples are valid and get their sizes */ + chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, + &chunk_sizes); + if (chunks_number < 0) + { return NULL; } - /* Prepare chunk_sizes */ - chunk_sizes = PyMem_New(size_t, chunks_number); - if (chunk_sizes == NULL) { - PyErr_NoMemory(); - goto error; - } - - sizes_sum = 0; - for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - chunk_sizes[i] = PyLong_AsSize_t(size); - if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); - goto error; - } - sizes_sum += chunk_sizes[i]; - } - - if (sizes_sum != Py_SIZE(samples_bytes)) { - PyErr_SetString(PyExc_ValueError, - "The samples size tuple doesn't match the concatenation's size."); - goto error; - } - /* Allocate dict buffer */ dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); if (dst_dict_bytes == NULL) { |