diff options
Diffstat (limited to 'Modules/_zstd')
-rw-r--r-- | Modules/_zstd/_zstdmodule.c | 114 | ||||
-rw-r--r-- | Modules/_zstd/_zstdmodule.h | 14 | ||||
-rw-r--r-- | Modules/_zstd/clinic/compressor.c.h | 41 | ||||
-rw-r--r-- | Modules/_zstd/compressor.c | 338 | ||||
-rw-r--r-- | Modules/_zstd/decompressor.c | 99 | ||||
-rw-r--r-- | Modules/_zstd/zstddict.c | 1 |
6 files changed, 350 insertions, 257 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 56ad999e5cd..d75c0779474 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -7,7 +7,6 @@ #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include <zstd.h> // ZSTD_*() #include <zdict.h> // ZDICT_*() @@ -20,14 +19,52 @@ module _zstd #include "clinic/_zstdmodule.c.h" +ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) +{ + if (state == NULL) { + return NULL; + } + + /* Check ZstdDict */ + if (PyObject_TypeCheck(dict, state->ZstdDict_type)) { + return (ZstdDict*)dict; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2 + && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { + int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return NULL; + } + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) + { + *ptype = type; + return (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be a ZstdDict object."); + return NULL; +} + /* Format error message and set ZstdError. */ void -set_zstd_error(const _zstd_state* const state, - error_type type, size_t zstd_ret) +set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret) { - char *msg; + const char *msg; assert(ZSTD_isError(zstd_ret)); + if (state == NULL) { + return; + } switch (type) { case ERR_DECOMPRESS: msg = "Unable to decompress Zstandard data: %s"; @@ -35,6 +72,9 @@ set_zstd_error(const _zstd_state* const state, case ERR_COMPRESS: msg = "Unable to compress Zstandard data: %s"; break; + case ERR_SET_PLEDGED_INPUT_SIZE: + msg = "Unable to set pledged uncompressed content size: %s"; + break; case ERR_LOAD_D_DICT: msg = "Unable to load Zstandard dictionary or prefix for " @@ -103,16 +143,13 @@ static const ParameterInfo dp_list[] = { }; void -set_parameter_error(const _zstd_state* const state, int is_compress, - int key_v, int value_v) +set_parameter_error(int is_compress, int key_v, int value_v) { ParameterInfo const *list; int list_size; - char const *name; char *type; ZSTD_bounds bounds; - int i; - char pos_msg[128]; + char pos_msg[64]; if (is_compress) { list = cp_list; @@ -126,8 +163,8 @@ set_parameter_error(const _zstd_state* const state, int is_compress, } /* Find parameter's name */ - name = NULL; - for (i = 0; i < list_size; i++) { + char const *name = NULL; + for (int i = 0; i < list_size; i++) { if (key_v == (list+i)->parameter) { name = (list+i)->parameter_name; break; @@ -149,20 +186,16 @@ set_parameter_error(const _zstd_state* const state, int is_compress, bounds = ZSTD_dParam_getBounds(key_v); } if (ZSTD_isError(bounds.error)) { - PyErr_Format(state->ZstdError, - "Invalid zstd %s parameter \"%s\".", + PyErr_Format(PyExc_ValueError, "invalid %s parameter '%s'", type, name); return; } /* Error message */ - PyErr_Format(state->ZstdError, - "Error when setting zstd %s parameter \"%s\", it " - "should %d <= value <= %d, provided value is %d. " - "(%d-bit build)", - type, name, - bounds.lowerBound, bounds.upperBound, value_v, - 8*(int)sizeof(Py_ssize_t)); + PyErr_Format(PyExc_ValueError, + "%s parameter '%s' received an illegal value %d; " + "the valid range is [%d, %d]", + type, name, value_v, bounds.lowerBound, bounds.upperBound); } static inline _zstd_state* @@ -181,7 +214,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t sizes_sum; Py_ssize_t i; - chunks_number = Py_SIZE(samples_sizes); + chunks_number = PyTuple_GET_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, "The number of samples should be <= %u.", UINT32_MAX); @@ -195,20 +228,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, return -1; } - sizes_sum = 0; + sizes_sum = PyBytes_GET_SIZE(samples_bytes); for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - (*chunk_sizes)[i] = PyLong_AsSize_t(size); - if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); + size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i)); + (*chunk_sizes)[i] = size; + if (size == (size_t)-1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + goto sum_error; + } return -1; } - sizes_sum += (*chunk_sizes)[i]; + if ((size_t)sizes_sum < size) { + goto sum_error; + } + sizes_sum -= size; } - if (sizes_sum != Py_SIZE(samples_bytes)) { + if (sizes_sum != 0) { +sum_error: PyErr_SetString(PyExc_ValueError, "The samples size tuple doesn't match the " "concatenation's size."); @@ -264,7 +301,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Train the dictionary */ char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes); - char *samples_buffer = PyBytes_AS_STRING(samples_bytes); + const char *samples_buffer = PyBytes_AS_STRING(samples_bytes); Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size, samples_buffer, @@ -514,20 +551,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, { _zstd_state* mod_state = get_zstd_state(module); - if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { - PyErr_SetString(PyExc_ValueError, - "The two arguments should be CompressionParameter and " - "DecompressionParameter types."); - return NULL; - } - - Py_XDECREF(mod_state->CParameter_type); Py_INCREF(c_parameter_type); - mod_state->CParameter_type = (PyTypeObject*)c_parameter_type; - - Py_XDECREF(mod_state->DParameter_type); + Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type); Py_INCREF(d_parameter_type); - mod_state->DParameter_type = (PyTypeObject*)d_parameter_type; + Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type); Py_RETURN_NONE; } @@ -590,7 +617,6 @@ do { \ return -1; } if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) { - Py_DECREF(mod_state->ZstdError); return -1; } diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index b36486442c6..4e8f708f223 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -5,6 +5,8 @@ #ifndef ZSTD_MODULE_H #define ZSTD_MODULE_H +#include "zstddict.h" + /* Type specs */ extern PyType_Spec zstd_dict_type_spec; extern PyType_Spec zstd_compressor_type_spec; @@ -25,6 +27,7 @@ typedef struct { typedef enum { ERR_DECOMPRESS, ERR_COMPRESS, + ERR_SET_PLEDGED_INPUT_SIZE, ERR_LOAD_D_DICT, ERR_LOAD_C_DICT, @@ -43,13 +46,16 @@ typedef enum { DICT_TYPE_PREFIX = 2 } dictionary_type; +extern ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, + PyObject *dict, int *type); + /* Format error message and set ZstdError. */ extern void -set_zstd_error(const _zstd_state* const state, - const error_type type, size_t zstd_ret); +set_zstd_error(const _zstd_state *state, + error_type type, size_t zstd_ret); extern void -set_parameter_error(const _zstd_state* const state, int is_compress, - int key_v, int value_v); +set_parameter_error(int is_compress, int key_v, int value_v); #endif // !ZSTD_MODULE_H diff --git a/Modules/_zstd/clinic/compressor.c.h b/Modules/_zstd/clinic/compressor.c.h index f69161b590e..4f8d93fd9e8 100644 --- a/Modules/_zstd/clinic/compressor.c.h +++ b/Modules/_zstd/clinic/compressor.c.h @@ -252,4 +252,43 @@ skip_optional_pos: exit: return return_value; } -/*[clinic end generated code: output=ee2d1dc298de790c input=a9049054013a1b77]*/ + +PyDoc_STRVAR(_zstd_ZstdCompressor_set_pledged_input_size__doc__, +"set_pledged_input_size($self, size, /)\n" +"--\n" +"\n" +"Set the uncompressed content size to be written into the frame header.\n" +"\n" +" size\n" +" The size of the uncompressed data to be provided to the compressor.\n" +"\n" +"This method can be used to ensure the header of the frame about to be written\n" +"includes the size of the data, unless the CompressionParameter.content_size_flag\n" +"is set to False. If last_mode != FLUSH_FRAME, then a RuntimeError is raised.\n" +"\n" +"It is important to ensure that the pledged data size matches the actual data\n" +"size. If they do not match the compressed output data may be corrupted and the\n" +"final chunk written may be lost."); + +#define _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF \ + {"set_pledged_input_size", (PyCFunction)_zstd_ZstdCompressor_set_pledged_input_size, METH_O, _zstd_ZstdCompressor_set_pledged_input_size__doc__}, + +static PyObject * +_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self, + unsigned long long size); + +static PyObject * +_zstd_ZstdCompressor_set_pledged_input_size(PyObject *self, PyObject *arg) +{ + PyObject *return_value = NULL; + unsigned long long size; + + if (!zstd_contentsize_converter(arg, &size)) { + goto exit; + } + return_value = _zstd_ZstdCompressor_set_pledged_input_size_impl((ZstdCompressor *)self, size); + +exit: + return return_value; +} +/*[clinic end generated code: output=c1d5c2cf06a8becd input=a9049054013a1b77]*/ diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 7f0558909b4..bc9e6eff89a 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include <stddef.h> // offsetof() @@ -46,101 +45,152 @@ typedef struct { #define ZstdCompressor_CAST(op) ((ZstdCompressor *)op) +/*[python input] + +class zstd_contentsize_converter(CConverter): + type = 'unsigned long long' + converter = 'zstd_contentsize_converter' + +[python start generated code]*/ +/*[python end generated code: output=da39a3ee5e6b4b0d input=0932c350d633c7de]*/ + + +static int +zstd_contentsize_converter(PyObject *size, unsigned long long *p) +{ + // None means the user indicates the size is unknown. + if (size == Py_None) { + *p = ZSTD_CONTENTSIZE_UNKNOWN; + } + else { + /* ZSTD_CONTENTSIZE_UNKNOWN is 0ULL - 1 + ZSTD_CONTENTSIZE_ERROR is 0ULL - 2 + Users should only pass values < ZSTD_CONTENTSIZE_ERROR */ + unsigned long long pledged_size = PyLong_AsUnsignedLongLong(size); + /* Here we check for (unsigned long long)-1 as a sign of an error in + PyLong_AsUnsignedLongLong */ + if (pledged_size == (unsigned long long)-1 && PyErr_Occurred()) { + *p = ZSTD_CONTENTSIZE_ERROR; + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + PyErr_Format(PyExc_ValueError, + "size argument should be a positive int less " + "than %ull", ZSTD_CONTENTSIZE_ERROR); + return 0; + } + return 0; + } + if (pledged_size >= ZSTD_CONTENTSIZE_ERROR) { + *p = ZSTD_CONTENTSIZE_ERROR; + PyErr_Format(PyExc_ValueError, + "size argument should be a positive int less " + "than %ull", ZSTD_CONTENTSIZE_ERROR); + return 0; + } + *p = pledged_size; + } + return 1; +} + #include "clinic/compressor.c.h" static int -_zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, - const char *arg_name, const char* arg_type) +_zstd_set_c_level(ZstdCompressor *self, int level) +{ + /* Set integer compression level */ + int min_level = ZSTD_minCLevel(); + int max_level = ZSTD_maxCLevel(); + if (level < min_level || level > max_level) { + PyErr_Format(PyExc_ValueError, + "illegal compression level %d; the valid range is [%d, %d]", + level, min_level, max_level); + return -1; + } + + /* Save for generating ZSTD_CDICT */ + self->compression_level = level; + + /* Set compressionLevel to compression context */ + size_t zstd_ret = ZSTD_CCtx_setParameter( + self->cctx, ZSTD_c_compressionLevel, level); + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); + set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret); + return -1; + } + return 0; +} + +static int +_zstd_set_c_parameters(ZstdCompressor *self, PyObject *options) { - size_t zstd_ret; _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } - /* Integer compression level */ - if (PyLong_Check(level_or_options)) { - int level = PyLong_AsInt(level_or_options); - if (level == -1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Compression level should be an int value between " - "%d and %d.", ZSTD_minCLevel(), ZSTD_maxCLevel()); + if (!PyDict_Check(options)) { + PyErr_Format(PyExc_TypeError, + "ZstdCompressor() argument 'options' must be dict, not %T", + options); + return -1; + } + + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(options, &pos, &key, &value)) { + /* Check key type */ + if (Py_TYPE(key) == mod_state->DParameter_type) { + PyErr_SetString(PyExc_TypeError, + "compression options dictionary key must not be a " + "DecompressionParameter attribute"); return -1; } - /* Save for generating ZSTD_CDICT */ - self->compression_level = level; - - /* Set compressionLevel to compression context */ - zstd_ret = ZSTD_CCtx_setParameter(self->cctx, - ZSTD_c_compressionLevel, - level); - - /* Check error */ - if (ZSTD_isError(zstd_ret)) { - set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret); + Py_INCREF(key); + Py_INCREF(value); + int key_v = PyLong_AsInt(key); + Py_DECREF(key); + if (key_v == -1 && PyErr_Occurred()) { + Py_DECREF(value); return -1; } - return 0; - } - - /* Options dict */ - if (PyDict_Check(level_or_options)) { - PyObject *key, *value; - Py_ssize_t pos = 0; - while (PyDict_Next(level_or_options, &pos, &key, &value)) { - /* Check key type */ - if (Py_TYPE(key) == mod_state->DParameter_type) { - PyErr_SetString(PyExc_TypeError, - "Key of compression options dict should " - "NOT be a DecompressionParameter attribute."); - return -1; - } + int value_v = PyLong_AsInt(value); + Py_DECREF(value); + if (value_v == -1 && PyErr_Occurred()) { + return -1; + } - int key_v = PyLong_AsInt(key); - if (key_v == -1 && PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "Key of options dict should be either a " - "CompressionParameter attribute or an int."); + if (key_v == ZSTD_c_compressionLevel) { + if (_zstd_set_c_level(self, value_v) < 0) { return -1; } - - int value_v = PyLong_AsInt(value); - if (value_v == -1 && PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "Value of options dict should be an int."); - return -1; + continue; + } + if (key_v == ZSTD_c_nbWorkers) { + /* From the zstd library docs: + 1. When nbWorkers >= 1, triggers asynchronous mode when + used with ZSTD_compressStream2(). + 2, Default value is `0`, aka "single-threaded mode" : no + worker is spawned, compression is performed inside + caller's thread, all invocations are blocking. */ + if (value_v != 0) { + self->use_multithread = 1; } + } - if (key_v == ZSTD_c_compressionLevel) { - /* Save for generating ZSTD_CDICT */ - self->compression_level = value_v; - } - else if (key_v == ZSTD_c_nbWorkers) { - /* From the zstd library docs: - 1. When nbWorkers >= 1, triggers asynchronous mode when - used with ZSTD_compressStream2(). - 2, Default value is `0`, aka "single-threaded mode" : no - worker is spawned, compression is performed inside - caller's thread, all invocations are blocking. */ - if (value_v != 0) { - self->use_multithread = 1; - } - } + /* Set parameter to compression context */ + size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v); - /* Set parameter to compression context */ - zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v); - if (ZSTD_isError(zstd_ret)) { - set_parameter_error(mod_state, 1, key_v, value_v); - return -1; - } + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_parameter_error(1, key_v, value_v); + return -1; } - return 0; } - PyErr_Format(PyExc_TypeError, - "Invalid type for %s. Expected %s", arg_name, arg_type); - return -1; + return 0; } static void @@ -257,56 +307,17 @@ static int _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { - return -1; - } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { + /* When compressing, use undigested dictionary by default. */ + int type = DICT_TYPE_UNDIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - else if (ret > 0) { - /* When compressing, use undigested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_UNDIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /*[clinic input] @@ -354,20 +365,35 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level, self->last_mode = ZSTD_e_end; if (level != Py_None && options != Py_None) { - PyErr_SetString(PyExc_RuntimeError, + PyErr_SetString(PyExc_TypeError, "Only one of level or options should be used."); goto error; } - /* Set compressLevel/options to compression context */ + /* Set compression level */ if (level != Py_None) { - if (_zstd_set_c_parameters(self, level, "level", "int") < 0) { + if (!PyLong_Check(level)) { + PyErr_SetString(PyExc_TypeError, + "invalid type for level, expected int"); + goto error; + } + int level_v = PyLong_AsInt(level); + if (level_v == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + PyErr_Format(PyExc_ValueError, + "illegal compression level; the valid range is [%d, %d]", + ZSTD_minCLevel(), ZSTD_maxCLevel()); + } + goto error; + } + if (_zstd_set_c_level(self, level_v) < 0) { goto error; } } + /* Set options dictionary */ if (options != Py_None) { - if (_zstd_set_c_parameters(self, options, "options", "dict") < 0) { + if (_zstd_set_c_parameters(self, options) < 0) { goto error; } } @@ -458,9 +484,7 @@ compress_lock_held(ZstdCompressor *self, Py_buffer *data, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } @@ -489,7 +513,7 @@ error: return NULL; } -#ifdef Py_DEBUG +#ifndef NDEBUG static inline int mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out) { @@ -530,9 +554,7 @@ compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data) /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } @@ -667,9 +689,61 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) return ret; } + +/*[clinic input] +_zstd.ZstdCompressor.set_pledged_input_size + + size: zstd_contentsize + The size of the uncompressed data to be provided to the compressor. + / + +Set the uncompressed content size to be written into the frame header. + +This method can be used to ensure the header of the frame about to be written +includes the size of the data, unless the CompressionParameter.content_size_flag +is set to False. If last_mode != FLUSH_FRAME, then a RuntimeError is raised. + +It is important to ensure that the pledged data size matches the actual data +size. If they do not match the compressed output data may be corrupted and the +final chunk written may be lost. +[clinic start generated code]*/ + +static PyObject * +_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self, + unsigned long long size) +/*[clinic end generated code: output=3a09e55cc0e3b4f9 input=afd8a7d78cff2eb5]*/ +{ + // Error occured while converting argument, should be unreachable + assert(size != ZSTD_CONTENTSIZE_ERROR); + + /* Thread-safe code */ + PyMutex_Lock(&self->lock); + + /* Check the current mode */ + if (self->last_mode != ZSTD_e_end) { + PyErr_SetString(PyExc_ValueError, + "set_pledged_input_size() method must be called " + "when last_mode == FLUSH_FRAME"); + PyMutex_Unlock(&self->lock); + return NULL; + } + + /* Set pledged content size */ + size_t zstd_ret = ZSTD_CCtx_setPledgedSrcSize(self->cctx, size); + PyMutex_Unlock(&self->lock); + if (ZSTD_isError(zstd_ret)) { + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); + set_zstd_error(mod_state, ERR_SET_PLEDGED_INPUT_SIZE, zstd_ret); + return NULL; + } + + Py_RETURN_NONE; +} + static PyMethodDef ZstdCompressor_methods[] = { _ZSTD_ZSTDCOMPRESSOR_COMPRESS_METHODDEF _ZSTD_ZSTDCOMPRESSOR_FLUSH_METHODDEF + _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF {NULL, NULL} }; diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 015cb774ed2..c53d6e4cb05 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include <stdbool.h> // bool @@ -61,11 +60,6 @@ _get_DDict(ZstdDict *self) assert(PyMutex_IsLocked(&self->lock)); ZSTD_DDict *ret; - /* Already created */ - if (self->d_dict != NULL) { - return self->d_dict; - } - if (self->d_dict == NULL) { /* Create ZSTD_DDict instance from dictionary content */ Py_BEGIN_ALLOW_THREADS @@ -86,56 +80,52 @@ _get_DDict(ZstdDict *self) return self->d_dict; } -/* Set decompression parameters to decompression context */ static int _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) { - size_t zstd_ret; - PyObject *key, *value; - Py_ssize_t pos; _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } if (!PyDict_Check(options)) { - PyErr_SetString(PyExc_TypeError, - "options argument should be dict object."); + PyErr_Format(PyExc_TypeError, + "ZstdDecompressor() argument 'options' must be dict, not %T", + options); return -1; } - pos = 0; + Py_ssize_t pos = 0; + PyObject *key, *value; while (PyDict_Next(options, &pos, &key, &value)) { /* Check key type */ if (Py_TYPE(key) == mod_state->CParameter_type) { PyErr_SetString(PyExc_TypeError, - "Key of decompression options dict should " - "NOT be a CompressionParameter attribute."); + "compression options dictionary key must not be a " + "CompressionParameter attribute"); return -1; } - /* Both key & value should be 32-bit signed int */ + Py_INCREF(key); + Py_INCREF(value); int key_v = PyLong_AsInt(key); + Py_DECREF(key); if (key_v == -1 && PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "Key of options dict should be either a " - "DecompressionParameter attribute or an int."); return -1; } int value_v = PyLong_AsInt(value); + Py_DECREF(value); if (value_v == -1 && PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "Value of options dict should be an int."); return -1; } /* Set parameter to compression context */ - zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v); + size_t zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v); /* Check error */ if (ZSTD_isError(zstd_ret)) { - set_parameter_error(mod_state, 0, key_v, value_v); + set_parameter_error(0, key_v, value_v); return -1; } } @@ -186,56 +176,17 @@ static int _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { + /* When decompressing, use digested dictionary by default. */ + int type = DICT_TYPE_DIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* When decompressing, use digested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_DIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* @@ -286,9 +237,7 @@ decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); goto error; } @@ -577,7 +526,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict, self->dict = zstd_dict; } - /* Set option to decompression context */ + /* Set options dictionary */ if (options != Py_None) { if (_zstd_set_d_parameters(self, options) < 0) { goto error; diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c index afc58b42e89..14f74aaed46 100644 --- a/Modules/_zstd/zstddict.c +++ b/Modules/_zstd/zstddict.c @@ -15,7 +15,6 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec" #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include "clinic/zstddict.c.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked |