diff options
Diffstat (limited to 'Modules/_zstd/_zstdmodule.c')
-rw-r--r-- | Modules/_zstd/_zstdmodule.c | 228 |
1 files changed, 129 insertions, 99 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 17d3bff1e98..d75c0779474 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -7,7 +7,6 @@ #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include <zstd.h> // ZSTD_*() #include <zdict.h> // ZDICT_*() @@ -20,49 +19,91 @@ module _zstd #include "clinic/_zstdmodule.c.h" +ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) +{ + if (state == NULL) { + return NULL; + } + + /* Check ZstdDict */ + if (PyObject_TypeCheck(dict, state->ZstdDict_type)) { + return (ZstdDict*)dict; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2 + && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { + int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return NULL; + } + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) + { + *ptype = type; + return (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be a ZstdDict object."); + return NULL; +} + /* Format error message and set ZstdError. */ void -set_zstd_error(const _zstd_state* const state, - error_type type, size_t zstd_ret) +set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret) { - char *msg; + const char *msg; assert(ZSTD_isError(zstd_ret)); - switch (type) - { - case ERR_DECOMPRESS: - msg = "Unable to decompress Zstandard data: %s"; - break; - case ERR_COMPRESS: - msg = "Unable to compress Zstandard data: %s"; - break; - - case ERR_LOAD_D_DICT: - msg = "Unable to load Zstandard dictionary or prefix for decompression: %s"; - break; - case ERR_LOAD_C_DICT: - msg = "Unable to load Zstandard dictionary or prefix for compression: %s"; - break; - - case ERR_GET_C_BOUNDS: - msg = "Unable to get zstd compression parameter bounds: %s"; - break; - case ERR_GET_D_BOUNDS: - msg = "Unable to get zstd decompression parameter bounds: %s"; - break; - case ERR_SET_C_LEVEL: - msg = "Unable to set zstd compression level: %s"; - break; - - case ERR_TRAIN_DICT: - msg = "Unable to train the Zstandard dictionary: %s"; - break; - case ERR_FINALIZE_DICT: - msg = "Unable to finalize the Zstandard dictionary: %s"; - break; - - default: - Py_UNREACHABLE(); + if (state == NULL) { + return; + } + switch (type) { + case ERR_DECOMPRESS: + msg = "Unable to decompress Zstandard data: %s"; + break; + case ERR_COMPRESS: + msg = "Unable to compress Zstandard data: %s"; + break; + case ERR_SET_PLEDGED_INPUT_SIZE: + msg = "Unable to set pledged uncompressed content size: %s"; + break; + + case ERR_LOAD_D_DICT: + msg = "Unable to load Zstandard dictionary or prefix for " + "decompression: %s"; + break; + case ERR_LOAD_C_DICT: + msg = "Unable to load Zstandard dictionary or prefix for " + "compression: %s"; + break; + + case ERR_GET_C_BOUNDS: + msg = "Unable to get zstd compression parameter bounds: %s"; + break; + case ERR_GET_D_BOUNDS: + msg = "Unable to get zstd decompression parameter bounds: %s"; + break; + case ERR_SET_C_LEVEL: + msg = "Unable to set zstd compression level: %s"; + break; + + case ERR_TRAIN_DICT: + msg = "Unable to train the Zstandard dictionary: %s"; + break; + case ERR_FINALIZE_DICT: + msg = "Unable to finalize the Zstandard dictionary: %s"; + break; + + default: + Py_UNREACHABLE(); } PyErr_Format(state->ZstdError, msg, ZSTD_getErrorName(zstd_ret)); } @@ -102,16 +143,13 @@ static const ParameterInfo dp_list[] = { }; void -set_parameter_error(const _zstd_state* const state, int is_compress, - int key_v, int value_v) +set_parameter_error(int is_compress, int key_v, int value_v) { ParameterInfo const *list; int list_size; - char const *name; char *type; ZSTD_bounds bounds; - int i; - char pos_msg[128]; + char pos_msg[64]; if (is_compress) { list = cp_list; @@ -125,8 +163,8 @@ set_parameter_error(const _zstd_state* const state, int is_compress, } /* Find parameter's name */ - name = NULL; - for (i = 0; i < list_size; i++) { + char const *name = NULL; + for (int i = 0; i < list_size; i++) { if (key_v == (list+i)->parameter) { name = (list+i)->parameter_name; break; @@ -148,20 +186,16 @@ set_parameter_error(const _zstd_state* const state, int is_compress, bounds = ZSTD_dParam_getBounds(key_v); } if (ZSTD_isError(bounds.error)) { - PyErr_Format(state->ZstdError, - "Invalid zstd %s parameter \"%s\".", + PyErr_Format(PyExc_ValueError, "invalid %s parameter '%s'", type, name); return; } /* Error message */ - PyErr_Format(state->ZstdError, - "Error when setting zstd %s parameter \"%s\", it " - "should %d <= value <= %d, provided value is %d. " - "(%d-bit build)", - type, name, - bounds.lowerBound, bounds.upperBound, value_v, - 8*(int)sizeof(Py_ssize_t)); + PyErr_Format(PyExc_ValueError, + "%s parameter '%s' received an illegal value %d; " + "the valid range is [%d, %d]", + type, name, value_v, bounds.lowerBound, bounds.upperBound); } static inline _zstd_state* @@ -180,10 +214,10 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t sizes_sum; Py_ssize_t i; - chunks_number = Py_SIZE(samples_sizes); + chunks_number = PyTuple_GET_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, - "The number of samples should be <= %u.", UINT32_MAX); + "The number of samples should be <= %u.", UINT32_MAX); return -1; } @@ -194,22 +228,27 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, return -1; } - sizes_sum = 0; + sizes_sum = PyBytes_GET_SIZE(samples_bytes); for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - (*chunk_sizes)[i] = PyLong_AsSize_t(size); - if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); + size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i)); + (*chunk_sizes)[i] = size; + if (size == (size_t)-1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + goto sum_error; + } return -1; } - sizes_sum += (*chunk_sizes)[i]; + if ((size_t)sizes_sum < size) { + goto sum_error; + } + sizes_sum -= size; } - if (sizes_sum != Py_SIZE(samples_bytes)) { + if (sizes_sum != 0) { +sum_error: PyErr_SetString(PyExc_ValueError, - "The samples size tuple doesn't match the concatenation's size."); + "The samples size tuple doesn't match the " + "concatenation's size."); return -1; } return chunks_number; @@ -242,15 +281,15 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Check arguments */ if (dict_size <= 0) { - PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); + PyErr_SetString(PyExc_ValueError, + "dict_size argument should be positive number."); return NULL; } /* Check that the samples are valid and get their sizes */ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, &chunk_sizes); - if (chunks_number < 0) - { + if (chunks_number < 0) { goto error; } @@ -262,7 +301,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Train the dictionary */ char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes); - char *samples_buffer = PyBytes_AS_STRING(samples_bytes); + const char *samples_buffer = PyBytes_AS_STRING(samples_bytes); Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size, samples_buffer, @@ -271,7 +310,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Check Zstandard dict error */ if (ZDICT_isError(zstd_ret)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_TRAIN_DICT, zstd_ret); goto error; } @@ -324,15 +363,15 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, /* Check arguments */ if (dict_size <= 0) { - PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number."); + PyErr_SetString(PyExc_ValueError, + "dict_size argument should be positive number."); return NULL; } /* Check that the samples are valid and get their sizes */ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes, &chunk_sizes); - if (chunks_number < 0) - { + if (chunks_number < 0) { goto error; } @@ -355,14 +394,15 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes, Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_finalizeDictionary( PyBytes_AS_STRING(dst_dict_bytes), dict_size, - PyBytes_AS_STRING(custom_dict_bytes), Py_SIZE(custom_dict_bytes), + PyBytes_AS_STRING(custom_dict_bytes), + Py_SIZE(custom_dict_bytes), PyBytes_AS_STRING(samples_bytes), chunk_sizes, (uint32_t)chunks_number, params); Py_END_ALLOW_THREADS /* Check Zstandard dict error */ if (ZDICT_isError(zstd_ret)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_FINALIZE_DICT, zstd_ret); goto error; } @@ -402,7 +442,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress) if (is_compress) { bound = ZSTD_cParam_getBounds(parameter); if (ZSTD_isError(bound.error)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_GET_C_BOUNDS, bound.error); return NULL; } @@ -410,7 +450,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress) else { bound = ZSTD_dParam_getBounds(parameter); if (ZSTD_isError(bound.error)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); set_zstd_error(mod_state, ERR_GET_D_BOUNDS, bound.error); return NULL; } @@ -435,9 +475,10 @@ _zstd_get_frame_size_impl(PyObject *module, Py_buffer *frame_buffer) { size_t frame_size; - frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, frame_buffer->len); + frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, + frame_buffer->len); if (ZSTD_isError(frame_size)) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); PyErr_Format(mod_state->ZstdError, "Error when finding the compressed size of a Zstandard frame. " "Ensure the frame_buffer argument starts from the " @@ -473,7 +514,7 @@ _zstd_get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer) /* #define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1) #define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) */ if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); PyErr_SetString(mod_state->ZstdError, "Error when getting information from the header of " "a Zstandard frame. Ensure the frame_buffer argument " @@ -508,22 +549,12 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, PyObject *d_parameter_type) /*[clinic end generated code: output=f3313b1294f19502 input=75d7a953580fae5f]*/ { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); - if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { - PyErr_SetString(PyExc_ValueError, - "The two arguments should be CompressionParameter and " - "DecompressionParameter types."); - return NULL; - } - - Py_XDECREF(mod_state->CParameter_type); Py_INCREF(c_parameter_type); - mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; - - Py_XDECREF(mod_state->DParameter_type); + Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type); Py_INCREF(d_parameter_type); - mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; + Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type); Py_RETURN_NONE; } @@ -568,7 +599,7 @@ do { \ Py_DECREF(v); \ } while (0) - _zstd_state* const mod_state = get_zstd_state(m); + _zstd_state* mod_state = get_zstd_state(m); /* Reusable objects & variables */ mod_state->CParameter_type = NULL; @@ -586,7 +617,6 @@ do { \ return -1; } if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) { - Py_DECREF(mod_state->ZstdError); return -1; } @@ -674,7 +704,7 @@ do { \ static int _zstd_traverse(PyObject *module, visitproc visit, void *arg) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); Py_VISIT(mod_state->ZstdDict_type); Py_VISIT(mod_state->ZstdCompressor_type); @@ -691,7 +721,7 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg) static int _zstd_clear(PyObject *module) { - _zstd_state* const mod_state = get_zstd_state(module); + _zstd_state* mod_state = get_zstd_state(module); Py_CLEAR(mod_state->ZstdDict_type); Py_CLEAR(mod_state->ZstdCompressor_type); |