aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--Modules/_zstd/_zstdmodule.c123
1 files changed, 55 insertions, 68 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c
index 0294828aa10..b2e4f95b906 100644
--- a/Modules/_zstd/_zstdmodule.c
+++ b/Modules/_zstd/_zstdmodule.c
@@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
return (_zstd_state *)state;
}
+static Py_ssize_t
+calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
+ size_t **chunk_sizes)
+{
+ Py_ssize_t chunks_number;
+ Py_ssize_t sizes_sum;
+ Py_ssize_t i;
+
+ chunks_number = Py_SIZE(samples_sizes);
+ if ((size_t) chunks_number > UINT32_MAX) {
+ PyErr_Format(PyExc_ValueError,
+ "The number of samples should be <= %u.", UINT32_MAX);
+ return -1;
+ }
+
+ /* Prepare chunk_sizes */
+ *chunk_sizes = PyMem_New(size_t, chunks_number);
+ if (*chunk_sizes == NULL) {
+ PyErr_NoMemory();
+ return -1;
+ }
+
+ sizes_sum = 0;
+ for (i = 0; i < chunks_number; i++) {
+ PyObject *size = PyTuple_GetItem(samples_sizes, i);
+ (*chunk_sizes)[i] = PyLong_AsSize_t(size);
+ if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
+ PyErr_Format(PyExc_ValueError,
+ "Items in samples_sizes should be an int "
+ "object, with a value between 0 and %u.", SIZE_MAX);
+ return -1;
+ }
+ sizes_sum += (*chunk_sizes)[i];
+ }
+
+ if (sizes_sum != Py_SIZE(samples_bytes)) {
+ PyErr_SetString(PyExc_ValueError,
+ "The samples size tuple doesn't match the concatenation's size.");
+ return -1;
+ }
+ return chunks_number;
+}
+
/*[clinic input]
_zstd.train_dict
@@ -192,14 +235,10 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
PyObject *samples_sizes, Py_ssize_t dict_size)
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
{
- // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
- // are pretty similar. We should see if we can refactor them to share that code.
- Py_ssize_t chunks_number;
- size_t *chunk_sizes = NULL;
PyObject *dst_dict_bytes = NULL;
+ size_t *chunk_sizes = NULL;
+ Py_ssize_t chunks_number;
size_t zstd_ret;
- Py_ssize_t sizes_sum;
- Py_ssize_t i;
/* Check arguments */
if (dict_size <= 0) {
@@ -207,39 +246,14 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
return NULL;
}
- chunks_number = Py_SIZE(samples_sizes);
- if ((size_t) chunks_number > UINT32_MAX) {
- PyErr_Format(PyExc_ValueError,
- "The number of samples should be <= %u.", UINT32_MAX);
+ /* Check that the samples are valid and get their sizes */
+ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
+ &chunk_sizes);
+ if (chunks_number < 0)
+ {
return NULL;
}
- /* Prepare chunk_sizes */
- chunk_sizes = PyMem_New(size_t, chunks_number);
- if (chunk_sizes == NULL) {
- PyErr_NoMemory();
- goto error;
- }
-
- sizes_sum = 0;
- for (i = 0; i < chunks_number; i++) {
- PyObject *size = PyTuple_GetItem(samples_sizes, i);
- chunk_sizes[i] = PyLong_AsSize_t(size);
- if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Items in samples_sizes should be an int "
- "object, with a value between 0 and %u.", SIZE_MAX);
- goto error;
- }
- sizes_sum += chunk_sizes[i];
- }
-
- if (sizes_sum != Py_SIZE(samples_bytes)) {
- PyErr_SetString(PyExc_ValueError,
- "The samples size tuple doesn't match the concatenation's size.");
- goto error;
- }
-
/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {
@@ -307,8 +321,6 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
PyObject *dst_dict_bytes = NULL;
size_t zstd_ret;
ZDICT_params_t params;
- Py_ssize_t sizes_sum;
- Py_ssize_t i;
/* Check arguments */
if (dict_size <= 0) {
@@ -316,39 +328,14 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
return NULL;
}
- chunks_number = Py_SIZE(samples_sizes);
- if ((size_t) chunks_number > UINT32_MAX) {
- PyErr_Format(PyExc_ValueError,
- "The number of samples should be <= %u.", UINT32_MAX);
+ /* Check that the samples are valid and get their sizes */
+ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
+ &chunk_sizes);
+ if (chunks_number < 0)
+ {
return NULL;
}
- /* Prepare chunk_sizes */
- chunk_sizes = PyMem_New(size_t, chunks_number);
- if (chunk_sizes == NULL) {
- PyErr_NoMemory();
- goto error;
- }
-
- sizes_sum = 0;
- for (i = 0; i < chunks_number; i++) {
- PyObject *size = PyTuple_GetItem(samples_sizes, i);
- chunk_sizes[i] = PyLong_AsSize_t(size);
- if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Items in samples_sizes should be an int "
- "object, with a value between 0 and %u.", SIZE_MAX);
- goto error;
- }
- sizes_sum += chunk_sizes[i];
- }
-
- if (sizes_sum != Py_SIZE(samples_bytes)) {
- PyErr_SetString(PyExc_ValueError,
- "The samples size tuple doesn't match the concatenation's size.");
- goto error;
- }
-
/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {