diff options
Diffstat (limited to 'Modules/_zstd/_zstdmodule.c')
-rw-r--r-- | Modules/_zstd/_zstdmodule.c | 261 |
1 files changed, 121 insertions, 140 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 0294828aa10..986b3579479 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -28,41 +28,42 @@ set_zstd_error(const _zstd_state* const state, char *msg; assert(ZSTD_isError(zstd_ret)); - switch (type) - { - case ERR_DECOMPRESS: - msg = "Unable to decompress Zstandard data: %s"; - break; - case ERR_COMPRESS: - msg = "Unable to compress Zstandard data: %s"; - break; - - case ERR_LOAD_D_DICT: - msg = "Unable to load Zstandard dictionary or prefix for decompression: %s"; - break; - case ERR_LOAD_C_DICT: - msg = "Unable to load Zstandard dictionary or prefix for compression: %s"; - break; - - case ERR_GET_C_BOUNDS: - msg = "Unable to get zstd compression parameter bounds: %s"; - break; - case ERR_GET_D_BOUNDS: - msg = "Unable to get zstd decompression parameter bounds: %s"; - break; - case ERR_SET_C_LEVEL: - msg = "Unable to set zstd compression level: %s"; - break; - - case ERR_TRAIN_DICT: - msg = "Unable to train the Zstandard dictionary: %s"; - break; - case ERR_FINALIZE_DICT: - msg = "Unable to finalize the Zstandard dictionary: %s"; - break; - - default: - Py_UNREACHABLE(); + switch (type) { + case ERR_DECOMPRESS: + msg = "Unable to decompress Zstandard data: %s"; + break; + case ERR_COMPRESS: + msg = "Unable to compress Zstandard data: %s"; + break; + + case ERR_LOAD_D_DICT: + msg = "Unable to load Zstandard dictionary or prefix for " + "decompression: %s"; + break; + case ERR_LOAD_C_DICT: + msg = "Unable to load Zstandard dictionary or prefix for " + "compression: %s"; + break; + + case ERR_GET_C_BOUNDS: + msg = "Unable to get zstd compression parameter bounds: %s"; + break; + case ERR_GET_D_BOUNDS: + msg = "Unable to get zstd decompression parameter bounds: %s"; + break; + case ERR_SET_C_LEVEL: + msg = "Unable to set zstd compression level: %s"; + break; + + case ERR_TRAIN_DICT: + msg = "Unable to train the Zstandard dictionary: %s"; + break; + case ERR_FINALIZE_DICT: + msg = "Unable to finalize the Zstandard dictionary: %s"; + break; + + default: + Py_UNREACHABLE(); } PyErr_Format(state->ZstdError, msg, ZSTD_getErrorName(zstd_ret)); } @@ -102,16 +103,13 @@ static const ParameterInfo dp_list[] = { }; void -set_parameter_error(const _zstd_state* const state, int is_compress, - int key_v, int value_v) +set_parameter_error(int is_compress, int key_v, int value_v) { ParameterInfo const *list; int list_size; - char const *name; char *type; ZSTD_bounds bounds; - int i; - char pos_msg[128]; + char pos_msg[64]; if (is_compress) { list = cp_list; @@ -125,8 +123,8 @@ set_parameter_error(const _zstd_state* const state, int is_compress, } /* Find parameter's name */ - name = NULL; - for (i = 0; i < list_size; i++) { + char const *name = NULL; + for (int i = 0; i < list_size; i++) { if (key_v == (list+i)->parameter) { name = (list+i)->parameter_name; break; @@ -148,20 +146,16 @@ set_parameter_error(const _zstd_state* const state, int is_compress, bounds = ZSTD_dParam_getBounds(key_v); } if (ZSTD_isError(bounds.error)) { - PyErr_Format(state->ZstdError, - "Invalid zstd %s parameter \"%s\".", + PyErr_Format(PyExc_ValueError, "invalid %s parameter '%s'", type, name); return; } /* Error message */ - PyErr_Format(state->ZstdError, - "Error when setting zstd %s parameter \"%s\", it " - "should %d <= value <= %d, provided value is %d. " - "(%d-bit build)", - type, name, - bounds.lowerBound, bounds.upperBound, value_v, - 8*(int)sizeof(Py_ssize_t)); + PyErr_Format(PyExc_ValueError, + "%s parameter '%s' received an illegal value %d; " + "the valid range is [%d, %d]", + type, name, value_v, bounds.lowerBound, bounds.upperBound); } static inline _zstd_state* @@ -172,6 +166,50 @@ get_zstd_state(PyObject *module) return (_zstd_state *)state; } +static Py_ssize_t +calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, + size_t **chunk_sizes) +{ + Py_ssize_t chunks_number; + Py_ssize_t sizes_sum; + Py_ssize_t i; + + chunks_number = Py_SIZE(samples_sizes); + if ((size_t) chunks_number > UINT32_MAX) { + PyErr_Format(PyExc_ValueError, + "The number of samples should be <= %u.", UINT32_MAX); + return -1; + } + + /* Prepare chunk_sizes */ + *chunk_sizes = PyMem_New(size_t, chunks_number); + if (*chunk_sizes == NULL) { + PyErr_NoMemory(); + return -1; + } + + sizes_sum = 0; + for (i = 0; i < chunks_number; i++) { + PyObject *size = PyTuple_GetItem(samples_sizes, i); + (*chunk_sizes)[i] = PyLong_AsSize_t(size); + if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { + PyErr_Format(PyExc_ValueError, + "Items in samples_sizes should be an int " + "object, with a value between 0 and %u.", SIZE_MAX); + return -1; + } + sizes_sum += (*chunk_sizes)[i]; + } + + if (sizes_sum != Py_SIZE(samples_bytes)) { + PyErr_SetString(PyExc_ValueError, + "The samples size tuple doesn't match the " + "concatenation's size."); + return -1; + } + return chunks_number; +} + /*[clinic input] _zstd.train_dict @@ -192,51 +230,22 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t dict_size) /*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/ { - // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict - // are pretty similar. We should see if we can refactor them to share that code. - Py_ssize_t chunks_number; - size_t *chunk_sizes = NULL; PyObject *dst_dict_bytes = NULL; + size_t *chunk_sizes = NULL; + Py_ssize_t chunks_number; size_t zstd_ret; - Py_ssize_t sizes_sum; - Py_ssize_t i; /* Check arguments */ if (dict_size <= 0) { - PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); - return NULL; - } - - chunks_number = Py_SIZE(samples_sizes); - if ((size_t) chunks_number > UINT32_MAX) { - PyErr_Format(PyExc_ValueError, - "The number of samples should be <= %u.", UINT32_MAX); + PyErr_SetString(PyExc_ValueError, + "dict_size argument should be positive number."); return NULL; } - /* Prepare chunk_sizes */ - chunk_sizes = PyMem_New(size_t, chunks_number); - if (chunk_sizes == NULL) { - PyErr_NoMemory(); - goto error; - } - - sizes_sum = 0; - for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - chunk_sizes[i] = PyLong_AsSize_t(size); - if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); - goto error; - } - sizes_sum += chunk_sizes[i]; - } - - if (sizes_sum != Py_SIZE(samples_bytes)) { - PyErr_SetString(PyExc_ValueError, - "The samples size tuple doesn't match the concatenation's size."); + /* Check that the samples are valid and get their sizes */ + chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, + &chunk_sizes); + if (chunks_number < 0) { goto error; } @@ -257,7 +266,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Check Zstandard dict error */ if (ZDICT_isError(zstd_ret)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_TRAIN_DICT, zstd_ret); goto error; } @@ -307,45 +316,18 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, PyObject *dst_dict_bytes = NULL; size_t zstd_ret; ZDICT_params_t params; - Py_ssize_t sizes_sum; - Py_ssize_t i; /* Check arguments */ if (dict_size <= 0) { - PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); - return NULL; - } - - chunks_number = Py_SIZE(samples_sizes); - if ((size_t) chunks_number > UINT32_MAX) { - PyErr_Format(PyExc_ValueError, - "The number of samples should be <= %u.", UINT32_MAX); + PyErr_SetString(PyExc_ValueError, + "dict_size argument should be positive number."); return NULL; } - /* Prepare chunk_sizes */ - chunk_sizes = PyMem_New(size_t, chunks_number); - if (chunk_sizes == NULL) { - PyErr_NoMemory(); - goto error; - } - - sizes_sum = 0; - for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - chunk_sizes[i] = PyLong_AsSize_t(size); - if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); - goto error; - } - sizes_sum += chunk_sizes[i]; - } - - if (sizes_sum != Py_SIZE(samples_bytes)) { - PyErr_SetString(PyExc_ValueError, - "The samples size tuple doesn't match the concatenation's size."); + /* Check that the samples are valid and get their sizes */ + chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, + &chunk_sizes); + if (chunks_number < 0) { goto error; } @@ -368,14 +350,15 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_finalizeDictionary( PyBytes_AS_STRING(dst_dict_bytes), dict_size, - PyBytes_AS_STRING(custom_dict_bytes), Py_SIZE(custom_dict_bytes), + PyBytes_AS_STRING(custom_dict_bytes), + Py_SIZE(custom_dict_bytes), PyBytes_AS_STRING(samples_bytes), chunk_sizes, (uint32_t)chunks_number, params); Py_END_ALLOW_THREADS /* Check Zstandard dict error */ if (ZDICT_isError(zstd_ret)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_FINALIZE_DICT, zstd_ret); goto error; } @@ -415,7 +398,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress) if (is_compress) { bound = ZSTD_cParam_getBounds(parameter); if (ZSTD_isError(bound.error)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_GET_C_BOUNDS, bound.error); return NULL; } @@ -423,7 +406,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress) else { bound = ZSTD_dParam_getBounds(parameter); if (ZSTD_isError(bound.error)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_GET_D_BOUNDS, bound.error); return NULL; } @@ -448,9 +431,10 @@ _zstd_get_frame_size_impl(PyObject *module, Py_buffer *frame_buffer) { size_t frame_size; - frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, frame_buffer->len); + frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, + frame_buffer->len); if (ZSTD_isError(frame_size)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); PyErr_Format(mod_state->ZstdError, "Error when finding the compressed size of a Zstandard frame. " "Ensure the frame_buffer argument starts from the " @@ -486,7 +470,7 @@ _zstd_get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) /* #define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1) #define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) */ if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); PyErr_SetString(mod_state->ZstdError, "Error when getting information from the header of " "a Zstandard frame. Ensure the frame_buffer argument " @@ -521,7 +505,7 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, PyObject *d_parameter_type) /*[clinic end generated code: output=f3313b1294f19502 input=75d7a953580fae5f]*/ { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { PyErr_SetString(PyExc_ValueError, @@ -530,13 +514,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, return NULL; } - Py_XDECREF(mod_state->CParameter_type); - Py_INCREF(c_parameter_type); - mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; - - Py_XDECREF(mod_state->DParameter_type); - Py_INCREF(d_parameter_type); - mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; + Py_XSETREF( + mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type)); + Py_XSETREF( + mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type)); Py_RETURN_NONE; } @@ -581,7 +562,7 @@ do { \ Py_DECREF(v); \ } while (0) - _zstd_state* const mod_state = get_zstd_state(m); + _zstd_state* mod_state = get_zstd_state(m); /* Reusable objects & variables */ mod_state->CParameter_type = NULL; @@ -687,7 +668,7 @@ do { \ static int _zstd_traverse(PyObject *module, visitproc visit, void *arg) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); Py_VISIT(mod_state->ZstdDict_type); Py_VISIT(mod_state->ZstdCompressor_type); @@ -704,7 +685,7 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg) static int _zstd_clear(PyObject *module) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); Py_CLEAR(mod_state->ZstdDict_type); Py_CLEAR(mod_state->ZstdCompressor_type); |