aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Modules/_zstd
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_zstd')
-rw-r--r--Modules/_zstd/_zstdmodule.c319
-rw-r--r--Modules/_zstd/_zstdmodule.h14
-rw-r--r--Modules/_zstd/buffer.h9
-rw-r--r--Modules/_zstd/clinic/compressor.c.h41
-rw-r--r--Modules/_zstd/clinic/decompressor.c.h11
-rw-r--r--Modules/_zstd/clinic/zstddict.c.h80
-rw-r--r--Modules/_zstd/compressor.c452
-rw-r--r--Modules/_zstd/decompressor.c206
-rw-r--r--Modules/_zstd/zstddict.c108
-rw-r--r--Modules/_zstd/zstddict.h9
10 files changed, 688 insertions, 561 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c
index 0294828aa10..d75c0779474 100644
--- a/Modules/_zstd/_zstdmodule.c
+++ b/Modules/_zstd/_zstdmodule.c
@@ -7,7 +7,6 @@
#include "Python.h"
#include "_zstdmodule.h"
-#include "zstddict.h"
#include <zstd.h> // ZSTD_*()
#include <zdict.h> // ZDICT_*()
@@ -20,49 +19,91 @@ module _zstd
#include "clinic/_zstdmodule.c.h"
+ZstdDict *
+_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
+{
+ if (state == NULL) {
+ return NULL;
+ }
+
+ /* Check ZstdDict */
+ if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
+ return (ZstdDict*)dict;
+ }
+
+ /* Check (ZstdDict, type) */
+ if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
+ && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
+ && PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
+ {
+ int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
+ if (type == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ if (type == DICT_TYPE_DIGESTED
+ || type == DICT_TYPE_UNDIGESTED
+ || type == DICT_TYPE_PREFIX)
+ {
+ *ptype = type;
+ return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
+ }
+ }
+
+ /* Wrong type */
+ PyErr_SetString(PyExc_TypeError,
+ "zstd_dict argument should be a ZstdDict object.");
+ return NULL;
+}
+
/* Format error message and set ZstdError. */
void
-set_zstd_error(const _zstd_state* const state,
- error_type type, size_t zstd_ret)
+set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
{
- char *msg;
+ const char *msg;
assert(ZSTD_isError(zstd_ret));
- switch (type)
- {
- case ERR_DECOMPRESS:
- msg = "Unable to decompress Zstandard data: %s";
- break;
- case ERR_COMPRESS:
- msg = "Unable to compress Zstandard data: %s";
- break;
-
- case ERR_LOAD_D_DICT:
- msg = "Unable to load Zstandard dictionary or prefix for decompression: %s";
- break;
- case ERR_LOAD_C_DICT:
- msg = "Unable to load Zstandard dictionary or prefix for compression: %s";
- break;
-
- case ERR_GET_C_BOUNDS:
- msg = "Unable to get zstd compression parameter bounds: %s";
- break;
- case ERR_GET_D_BOUNDS:
- msg = "Unable to get zstd decompression parameter bounds: %s";
- break;
- case ERR_SET_C_LEVEL:
- msg = "Unable to set zstd compression level: %s";
- break;
-
- case ERR_TRAIN_DICT:
- msg = "Unable to train the Zstandard dictionary: %s";
- break;
- case ERR_FINALIZE_DICT:
- msg = "Unable to finalize the Zstandard dictionary: %s";
- break;
-
- default:
- Py_UNREACHABLE();
+ if (state == NULL) {
+ return;
+ }
+ switch (type) {
+ case ERR_DECOMPRESS:
+ msg = "Unable to decompress Zstandard data: %s";
+ break;
+ case ERR_COMPRESS:
+ msg = "Unable to compress Zstandard data: %s";
+ break;
+ case ERR_SET_PLEDGED_INPUT_SIZE:
+ msg = "Unable to set pledged uncompressed content size: %s";
+ break;
+
+ case ERR_LOAD_D_DICT:
+ msg = "Unable to load Zstandard dictionary or prefix for "
+ "decompression: %s";
+ break;
+ case ERR_LOAD_C_DICT:
+ msg = "Unable to load Zstandard dictionary or prefix for "
+ "compression: %s";
+ break;
+
+ case ERR_GET_C_BOUNDS:
+ msg = "Unable to get zstd compression parameter bounds: %s";
+ break;
+ case ERR_GET_D_BOUNDS:
+ msg = "Unable to get zstd decompression parameter bounds: %s";
+ break;
+ case ERR_SET_C_LEVEL:
+ msg = "Unable to set zstd compression level: %s";
+ break;
+
+ case ERR_TRAIN_DICT:
+ msg = "Unable to train the Zstandard dictionary: %s";
+ break;
+ case ERR_FINALIZE_DICT:
+ msg = "Unable to finalize the Zstandard dictionary: %s";
+ break;
+
+ default:
+ Py_UNREACHABLE();
}
PyErr_Format(state->ZstdError, msg, ZSTD_getErrorName(zstd_ret));
}
@@ -102,16 +143,13 @@ static const ParameterInfo dp_list[] = {
};
void
-set_parameter_error(const _zstd_state* const state, int is_compress,
- int key_v, int value_v)
+set_parameter_error(int is_compress, int key_v, int value_v)
{
ParameterInfo const *list;
int list_size;
- char const *name;
char *type;
ZSTD_bounds bounds;
- int i;
- char pos_msg[128];
+ char pos_msg[64];
if (is_compress) {
list = cp_list;
@@ -125,8 +163,8 @@ set_parameter_error(const _zstd_state* const state, int is_compress,
}
/* Find parameter's name */
- name = NULL;
- for (i = 0; i < list_size; i++) {
+ char const *name = NULL;
+ for (int i = 0; i < list_size; i++) {
if (key_v == (list+i)->parameter) {
name = (list+i)->parameter_name;
break;
@@ -148,20 +186,16 @@ set_parameter_error(const _zstd_state* const state, int is_compress,
bounds = ZSTD_dParam_getBounds(key_v);
}
if (ZSTD_isError(bounds.error)) {
- PyErr_Format(state->ZstdError,
- "Invalid zstd %s parameter \"%s\".",
+ PyErr_Format(PyExc_ValueError, "invalid %s parameter '%s'",
type, name);
return;
}
/* Error message */
- PyErr_Format(state->ZstdError,
- "Error when setting zstd %s parameter \"%s\", it "
- "should %d <= value <= %d, provided value is %d. "
- "(%d-bit build)",
- type, name,
- bounds.lowerBound, bounds.upperBound, value_v,
- 8*(int)sizeof(Py_ssize_t));
+ PyErr_Format(PyExc_ValueError,
+ "%s parameter '%s' received an illegal value %d; "
+ "the valid range is [%d, %d]",
+ type, name, value_v, bounds.lowerBound, bounds.upperBound);
}
static inline _zstd_state*
@@ -172,6 +206,54 @@ get_zstd_state(PyObject *module)
return (_zstd_state *)state;
}
+static Py_ssize_t
+calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
+ size_t **chunk_sizes)
+{
+ Py_ssize_t chunks_number;
+ Py_ssize_t sizes_sum;
+ Py_ssize_t i;
+
+ chunks_number = PyTuple_GET_SIZE(samples_sizes);
+ if ((size_t) chunks_number > UINT32_MAX) {
+ PyErr_Format(PyExc_ValueError,
+ "The number of samples should be <= %u.", UINT32_MAX);
+ return -1;
+ }
+
+ /* Prepare chunk_sizes */
+ *chunk_sizes = PyMem_New(size_t, chunks_number);
+ if (*chunk_sizes == NULL) {
+ PyErr_NoMemory();
+ return -1;
+ }
+
+ sizes_sum = PyBytes_GET_SIZE(samples_bytes);
+ for (i = 0; i < chunks_number; i++) {
+ size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
+ (*chunk_sizes)[i] = size;
+ if (size == (size_t)-1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ goto sum_error;
+ }
+ return -1;
+ }
+ if ((size_t)sizes_sum < size) {
+ goto sum_error;
+ }
+ sizes_sum -= size;
+ }
+
+ if (sizes_sum != 0) {
+sum_error:
+ PyErr_SetString(PyExc_ValueError,
+ "The samples size tuple doesn't match the "
+ "concatenation's size.");
+ return -1;
+ }
+ return chunks_number;
+}
+
/*[clinic input]
_zstd.train_dict
@@ -192,51 +274,22 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
PyObject *samples_sizes, Py_ssize_t dict_size)
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
{
- // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
- // are pretty similar. We should see if we can refactor them to share that code.
- Py_ssize_t chunks_number;
- size_t *chunk_sizes = NULL;
PyObject *dst_dict_bytes = NULL;
+ size_t *chunk_sizes = NULL;
+ Py_ssize_t chunks_number;
size_t zstd_ret;
- Py_ssize_t sizes_sum;
- Py_ssize_t i;
/* Check arguments */
if (dict_size <= 0) {
- PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
- return NULL;
- }
-
- chunks_number = Py_SIZE(samples_sizes);
- if ((size_t) chunks_number > UINT32_MAX) {
- PyErr_Format(PyExc_ValueError,
- "The number of samples should be <= %u.", UINT32_MAX);
+ PyErr_SetString(PyExc_ValueError,
+ "dict_size argument should be positive number.");
return NULL;
}
- /* Prepare chunk_sizes */
- chunk_sizes = PyMem_New(size_t, chunks_number);
- if (chunk_sizes == NULL) {
- PyErr_NoMemory();
- goto error;
- }
-
- sizes_sum = 0;
- for (i = 0; i < chunks_number; i++) {
- PyObject *size = PyTuple_GetItem(samples_sizes, i);
- chunk_sizes[i] = PyLong_AsSize_t(size);
- if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Items in samples_sizes should be an int "
- "object, with a value between 0 and %u.", SIZE_MAX);
- goto error;
- }
- sizes_sum += chunk_sizes[i];
- }
-
- if (sizes_sum != Py_SIZE(samples_bytes)) {
- PyErr_SetString(PyExc_ValueError,
- "The samples size tuple doesn't match the concatenation's size.");
+ /* Check that the samples are valid and get their sizes */
+ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
+ &chunk_sizes);
+ if (chunks_number < 0) {
goto error;
}
@@ -248,7 +301,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
/* Train the dictionary */
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
- char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
+ const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
Py_BEGIN_ALLOW_THREADS
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
samples_buffer,
@@ -257,7 +310,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
/* Check Zstandard dict error */
if (ZDICT_isError(zstd_ret)) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
set_zstd_error(mod_state, ERR_TRAIN_DICT, zstd_ret);
goto error;
}
@@ -307,45 +360,18 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
PyObject *dst_dict_bytes = NULL;
size_t zstd_ret;
ZDICT_params_t params;
- Py_ssize_t sizes_sum;
- Py_ssize_t i;
/* Check arguments */
if (dict_size <= 0) {
- PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
- return NULL;
- }
-
- chunks_number = Py_SIZE(samples_sizes);
- if ((size_t) chunks_number > UINT32_MAX) {
- PyErr_Format(PyExc_ValueError,
- "The number of samples should be <= %u.", UINT32_MAX);
+ PyErr_SetString(PyExc_ValueError,
+ "dict_size argument should be positive number.");
return NULL;
}
- /* Prepare chunk_sizes */
- chunk_sizes = PyMem_New(size_t, chunks_number);
- if (chunk_sizes == NULL) {
- PyErr_NoMemory();
- goto error;
- }
-
- sizes_sum = 0;
- for (i = 0; i < chunks_number; i++) {
- PyObject *size = PyTuple_GetItem(samples_sizes, i);
- chunk_sizes[i] = PyLong_AsSize_t(size);
- if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Items in samples_sizes should be an int "
- "object, with a value between 0 and %u.", SIZE_MAX);
- goto error;
- }
- sizes_sum += chunk_sizes[i];
- }
-
- if (sizes_sum != Py_SIZE(samples_bytes)) {
- PyErr_SetString(PyExc_ValueError,
- "The samples size tuple doesn't match the concatenation's size.");
+ /* Check that the samples are valid and get their sizes */
+ chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
+ &chunk_sizes);
+ if (chunks_number < 0) {
goto error;
}
@@ -368,14 +394,15 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
Py_BEGIN_ALLOW_THREADS
zstd_ret = ZDICT_finalizeDictionary(
PyBytes_AS_STRING(dst_dict_bytes), dict_size,
- PyBytes_AS_STRING(custom_dict_bytes), Py_SIZE(custom_dict_bytes),
+ PyBytes_AS_STRING(custom_dict_bytes),
+ Py_SIZE(custom_dict_bytes),
PyBytes_AS_STRING(samples_bytes), chunk_sizes,
(uint32_t)chunks_number, params);
Py_END_ALLOW_THREADS
/* Check Zstandard dict error */
if (ZDICT_isError(zstd_ret)) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
set_zstd_error(mod_state, ERR_FINALIZE_DICT, zstd_ret);
goto error;
}
@@ -415,7 +442,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress)
if (is_compress) {
bound = ZSTD_cParam_getBounds(parameter);
if (ZSTD_isError(bound.error)) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
set_zstd_error(mod_state, ERR_GET_C_BOUNDS, bound.error);
return NULL;
}
@@ -423,7 +450,7 @@ _zstd_get_param_bounds_impl(PyObject *module, int parameter, int is_compress)
else {
bound = ZSTD_dParam_getBounds(parameter);
if (ZSTD_isError(bound.error)) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
set_zstd_error(mod_state, ERR_GET_D_BOUNDS, bound.error);
return NULL;
}
@@ -448,9 +475,10 @@ _zstd_get_frame_size_impl(PyObject *module, Py_buffer *frame_buffer)
{
size_t frame_size;
- frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf, frame_buffer->len);
+ frame_size = ZSTD_findFrameCompressedSize(frame_buffer->buf,
+ frame_buffer->len);
if (ZSTD_isError(frame_size)) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
PyErr_Format(mod_state->ZstdError,
"Error when finding the compressed size of a Zstandard frame. "
"Ensure the frame_buffer argument starts from the "
@@ -486,7 +514,7 @@ _zstd_get_frame_info_impl(PyObject *module, Py_buffer *frame_buffer)
/* #define ZSTD_CONTENTSIZE_UNKNOWN (0ULL - 1)
#define ZSTD_CONTENTSIZE_ERROR (0ULL - 2) */
if (decompressed_size == ZSTD_CONTENTSIZE_ERROR) {
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
PyErr_SetString(mod_state->ZstdError,
"Error when getting information from the header of "
"a Zstandard frame. Ensure the frame_buffer argument "
@@ -521,22 +549,12 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
PyObject *d_parameter_type)
/*[clinic end generated code: output=f3313b1294f19502 input=75d7a953580fae5f]*/
{
- _zstd_state* const mod_state = get_zstd_state(module);
-
- if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
- PyErr_SetString(PyExc_ValueError,
- "The two arguments should be CompressionParameter and "
- "DecompressionParameter types.");
- return NULL;
- }
+ _zstd_state* mod_state = get_zstd_state(module);
- Py_XDECREF(mod_state->CParameter_type);
Py_INCREF(c_parameter_type);
- mod_state->CParameter_type = (PyTypeObject*)c_parameter_type;
-
- Py_XDECREF(mod_state->DParameter_type);
+ Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
Py_INCREF(d_parameter_type);
- mod_state->DParameter_type = (PyTypeObject*)d_parameter_type;
+ Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
Py_RETURN_NONE;
}
@@ -581,7 +599,7 @@ do { \
Py_DECREF(v); \
} while (0)
- _zstd_state* const mod_state = get_zstd_state(m);
+ _zstd_state* mod_state = get_zstd_state(m);
/* Reusable objects & variables */
mod_state->CParameter_type = NULL;
@@ -599,7 +617,6 @@ do { \
return -1;
}
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
- Py_DECREF(mod_state->ZstdError);
return -1;
}
@@ -687,7 +704,7 @@ do { \
static int
_zstd_traverse(PyObject *module, visitproc visit, void *arg)
{
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
Py_VISIT(mod_state->ZstdDict_type);
Py_VISIT(mod_state->ZstdCompressor_type);
@@ -704,7 +721,7 @@ _zstd_traverse(PyObject *module, visitproc visit, void *arg)
static int
_zstd_clear(PyObject *module)
{
- _zstd_state* const mod_state = get_zstd_state(module);
+ _zstd_state* mod_state = get_zstd_state(module);
Py_CLEAR(mod_state->ZstdDict_type);
Py_CLEAR(mod_state->ZstdCompressor_type);
diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h
index b36486442c6..4e8f708f223 100644
--- a/Modules/_zstd/_zstdmodule.h
+++ b/Modules/_zstd/_zstdmodule.h
@@ -5,6 +5,8 @@
#ifndef ZSTD_MODULE_H
#define ZSTD_MODULE_H
+#include "zstddict.h"
+
/* Type specs */
extern PyType_Spec zstd_dict_type_spec;
extern PyType_Spec zstd_compressor_type_spec;
@@ -25,6 +27,7 @@ typedef struct {
typedef enum {
ERR_DECOMPRESS,
ERR_COMPRESS,
+ ERR_SET_PLEDGED_INPUT_SIZE,
ERR_LOAD_D_DICT,
ERR_LOAD_C_DICT,
@@ -43,13 +46,16 @@ typedef enum {
DICT_TYPE_PREFIX = 2
} dictionary_type;
+extern ZstdDict *
+_Py_parse_zstd_dict(const _zstd_state *state,
+ PyObject *dict, int *type);
+
/* Format error message and set ZstdError. */
extern void
-set_zstd_error(const _zstd_state* const state,
- const error_type type, size_t zstd_ret);
+set_zstd_error(const _zstd_state *state,
+ error_type type, size_t zstd_ret);
extern void
-set_parameter_error(const _zstd_state* const state, int is_compress,
- int key_v, int value_v);
+set_parameter_error(int is_compress, int key_v, int value_v);
#endif // !ZSTD_MODULE_H
diff --git a/Modules/_zstd/buffer.h b/Modules/_zstd/buffer.h
index bff3a81d8aa..4c885fa0d72 100644
--- a/Modules/_zstd/buffer.h
+++ b/Modules/_zstd/buffer.h
@@ -19,7 +19,8 @@ _OutputBuffer_InitAndGrow(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob,
/* Ensure .list was set to NULL */
assert(buffer->list == NULL);
- Py_ssize_t res = _BlocksOutputBuffer_InitAndGrow(buffer, max_length, &ob->dst);
+ Py_ssize_t res = _BlocksOutputBuffer_InitAndGrow(buffer, max_length,
+ &ob->dst);
if (res < 0) {
return -1;
}
@@ -34,8 +35,7 @@ _OutputBuffer_InitAndGrow(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob,
Return -1 on failure */
static inline int
_OutputBuffer_InitWithSize(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob,
- Py_ssize_t max_length,
- Py_ssize_t init_size)
+ Py_ssize_t max_length, Py_ssize_t init_size)
{
Py_ssize_t block_size;
@@ -50,7 +50,8 @@ _OutputBuffer_InitWithSize(_BlocksOutputBuffer *buffer, ZSTD_outBuffer *ob,
block_size = init_size;
}
- Py_ssize_t res = _BlocksOutputBuffer_InitWithSize(buffer, block_size, &ob->dst);
+ Py_ssize_t res = _BlocksOutputBuffer_InitWithSize(buffer, block_size,
+ &ob->dst);
if (res < 0) {
return -1;
}
diff --git a/Modules/_zstd/clinic/compressor.c.h b/Modules/_zstd/clinic/compressor.c.h
index f69161b590e..4f8d93fd9e8 100644
--- a/Modules/_zstd/clinic/compressor.c.h
+++ b/Modules/_zstd/clinic/compressor.c.h
@@ -252,4 +252,43 @@ skip_optional_pos:
exit:
return return_value;
}
-/*[clinic end generated code: output=ee2d1dc298de790c input=a9049054013a1b77]*/
+
+PyDoc_STRVAR(_zstd_ZstdCompressor_set_pledged_input_size__doc__,
+"set_pledged_input_size($self, size, /)\n"
+"--\n"
+"\n"
+"Set the uncompressed content size to be written into the frame header.\n"
+"\n"
+" size\n"
+" The size of the uncompressed data to be provided to the compressor.\n"
+"\n"
+"This method can be used to ensure the header of the frame about to be written\n"
+"includes the size of the data, unless the CompressionParameter.content_size_flag\n"
+"is set to False. If last_mode != FLUSH_FRAME, then a RuntimeError is raised.\n"
+"\n"
+"It is important to ensure that the pledged data size matches the actual data\n"
+"size. If they do not match the compressed output data may be corrupted and the\n"
+"final chunk written may be lost.");
+
+#define _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF \
+ {"set_pledged_input_size", (PyCFunction)_zstd_ZstdCompressor_set_pledged_input_size, METH_O, _zstd_ZstdCompressor_set_pledged_input_size__doc__},
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self,
+ unsigned long long size);
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size(PyObject *self, PyObject *arg)
+{
+ PyObject *return_value = NULL;
+ unsigned long long size;
+
+ if (!zstd_contentsize_converter(arg, &size)) {
+ goto exit;
+ }
+ return_value = _zstd_ZstdCompressor_set_pledged_input_size_impl((ZstdCompressor *)self, size);
+
+exit:
+ return return_value;
+}
+/*[clinic end generated code: output=c1d5c2cf06a8becd input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/clinic/decompressor.c.h b/Modules/_zstd/clinic/decompressor.c.h
index 4ecb19e9bde..c6fdae74ab0 100644
--- a/Modules/_zstd/clinic/decompressor.c.h
+++ b/Modules/_zstd/clinic/decompressor.c.h
@@ -7,7 +7,6 @@ preserve
# include "pycore_runtime.h" // _Py_ID()
#endif
#include "pycore_abstract.h" // _PyNumber_Index()
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_UnpackKeywords()
PyDoc_STRVAR(_zstd_ZstdDecompressor_new__doc__,
@@ -114,13 +113,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self);
static PyObject *
_zstd_ZstdDecompressor_unused_data_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDecompressor_unused_data_get_impl((ZstdDecompressor *)self);
}
PyDoc_STRVAR(_zstd_ZstdDecompressor_decompress__doc__,
@@ -227,4 +220,4 @@ exit:
return return_value;
}
-/*[clinic end generated code: output=7a4d278f9244e684 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=30c12ef047027ede input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/clinic/zstddict.c.h b/Modules/_zstd/clinic/zstddict.c.h
index 34e0e4b3ecf..79db85405d6 100644
--- a/Modules/_zstd/clinic/zstddict.c.h
+++ b/Modules/_zstd/clinic/zstddict.c.h
@@ -6,7 +6,6 @@ preserve
# include "pycore_gc.h" // PyGC_Head
# include "pycore_runtime.h" // _Py_ID()
#endif
-#include "pycore_critical_section.h"// Py_BEGIN_CRITICAL_SECTION()
#include "pycore_modsupport.h" // _PyArg_UnpackKeywords()
PyDoc_STRVAR(_zstd_ZstdDict_new__doc__,
@@ -26,7 +25,7 @@ PyDoc_STRVAR(_zstd_ZstdDict_new__doc__,
"by multiple ZstdCompressor or ZstdDecompressor objects.");
static PyObject *
-_zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content,
+_zstd_ZstdDict_new_impl(PyTypeObject *type, Py_buffer *dict_content,
int is_raw);
static PyObject *
@@ -64,7 +63,7 @@ _zstd_ZstdDict_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
PyObject * const *fastargs;
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1;
- PyObject *dict_content;
+ Py_buffer dict_content = {NULL, NULL};
int is_raw = 0;
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser,
@@ -72,7 +71,9 @@ _zstd_ZstdDict_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
if (!fastargs) {
goto exit;
}
- dict_content = fastargs[0];
+ if (PyObject_GetBuffer(fastargs[0], &dict_content, PyBUF_SIMPLE) != 0) {
+ goto exit;
+ }
if (!noptargs) {
goto skip_optional_kwonly;
}
@@ -81,16 +82,49 @@ _zstd_ZstdDict_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
goto exit;
}
skip_optional_kwonly:
- return_value = _zstd_ZstdDict_new_impl(type, dict_content, is_raw);
+ return_value = _zstd_ZstdDict_new_impl(type, &dict_content, is_raw);
exit:
+ /* Cleanup for dict_content */
+ if (dict_content.obj) {
+ PyBuffer_Release(&dict_content);
+ }
+
return return_value;
}
+PyDoc_STRVAR(_zstd_ZstdDict_dict_content__doc__,
+"The content of a Zstandard dictionary, as a bytes object.");
+#if defined(_zstd_ZstdDict_dict_content_DOCSTR)
+# undef _zstd_ZstdDict_dict_content_DOCSTR
+#endif
+#define _zstd_ZstdDict_dict_content_DOCSTR _zstd_ZstdDict_dict_content__doc__
+
+#if !defined(_zstd_ZstdDict_dict_content_DOCSTR)
+# define _zstd_ZstdDict_dict_content_DOCSTR NULL
+#endif
+#if defined(_ZSTD_ZSTDDICT_DICT_CONTENT_GETSETDEF)
+# undef _ZSTD_ZSTDDICT_DICT_CONTENT_GETSETDEF
+# define _ZSTD_ZSTDDICT_DICT_CONTENT_GETSETDEF {"dict_content", (getter)_zstd_ZstdDict_dict_content_get, (setter)_zstd_ZstdDict_dict_content_set, _zstd_ZstdDict_dict_content_DOCSTR},
+#else
+# define _ZSTD_ZSTDDICT_DICT_CONTENT_GETSETDEF {"dict_content", (getter)_zstd_ZstdDict_dict_content_get, NULL, _zstd_ZstdDict_dict_content_DOCSTR},
+#endif
+
+static PyObject *
+_zstd_ZstdDict_dict_content_get_impl(ZstdDict *self);
+
+static PyObject *
+_zstd_ZstdDict_dict_content_get(PyObject *self, void *Py_UNUSED(context))
+{
+ return _zstd_ZstdDict_dict_content_get_impl((ZstdDict *)self);
+}
+
PyDoc_STRVAR(_zstd_ZstdDict_as_digested_dict__doc__,
"Load as a digested dictionary to compressor.\n"
"\n"
-"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digested_dict)\n"
+"Pass this attribute as zstd_dict argument:\n"
+"compress(dat, zstd_dict=zd.as_digested_dict)\n"
+"\n"
"1. Some advanced compression parameters of compressor may be overridden\n"
" by parameters of digested dictionary.\n"
"2. ZstdDict has a digested dictionaries cache for each compression level.\n"
@@ -118,19 +152,15 @@ _zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_digested_dict_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_digested_dict_get_impl((ZstdDict *)self);
}
PyDoc_STRVAR(_zstd_ZstdDict_as_undigested_dict__doc__,
"Load as an undigested dictionary to compressor.\n"
"\n"
-"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undigested_dict)\n"
+"Pass this attribute as zstd_dict argument:\n"
+"compress(dat, zstd_dict=zd.as_undigested_dict)\n"
+"\n"
"1. The advanced compression parameters of compressor will not be overridden.\n"
"2. Loading an undigested dictionary is costly. If load an undigested dictionary\n"
" multiple times, consider reusing a compressor object.\n"
@@ -156,19 +186,15 @@ _zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_undigested_dict_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_undigested_dict_get_impl((ZstdDict *)self);
}
PyDoc_STRVAR(_zstd_ZstdDict_as_prefix__doc__,
"Load as a prefix to compressor/decompressor.\n"
"\n"
-"Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)\n"
+"Pass this attribute as zstd_dict argument:\n"
+"compress(dat, zstd_dict=zd.as_prefix)\n"
+"\n"
"1. Prefix is compatible with long distance matching, while dictionary is not.\n"
"2. It only works for the first frame, then the compressor/decompressor will\n"
" return to no prefix state.\n"
@@ -194,12 +220,6 @@ _zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self);
static PyObject *
_zstd_ZstdDict_as_prefix_get(PyObject *self, void *Py_UNUSED(context))
{
- PyObject *return_value = NULL;
-
- Py_BEGIN_CRITICAL_SECTION(self);
- return_value = _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
- Py_END_CRITICAL_SECTION();
-
- return return_value;
+ return _zstd_ZstdDict_as_prefix_get_impl((ZstdDict *)self);
}
-/*[clinic end generated code: output=bfb31c1187477afd input=a9049054013a1b77]*/
+/*[clinic end generated code: output=4696cbc722e5fdfc input=a9049054013a1b77]*/
diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c
index 38baee2be1e..bc9e6eff89a 100644
--- a/Modules/_zstd/compressor.c
+++ b/Modules/_zstd/compressor.c
@@ -16,7 +16,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
-#include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <stddef.h> // offsetof()
#include <zstd.h> // ZSTD_*()
@@ -38,105 +38,159 @@ typedef struct {
/* Compression level */
int compression_level;
+
+ /* Lock to protect the compression context */
+ PyMutex lock;
} ZstdCompressor;
#define ZstdCompressor_CAST(op) ((ZstdCompressor *)op)
+/*[python input]
+
+class zstd_contentsize_converter(CConverter):
+ type = 'unsigned long long'
+ converter = 'zstd_contentsize_converter'
+
+[python start generated code]*/
+/*[python end generated code: output=da39a3ee5e6b4b0d input=0932c350d633c7de]*/
+
+
+static int
+zstd_contentsize_converter(PyObject *size, unsigned long long *p)
+{
+ // None means the user indicates the size is unknown.
+ if (size == Py_None) {
+ *p = ZSTD_CONTENTSIZE_UNKNOWN;
+ }
+ else {
+ /* ZSTD_CONTENTSIZE_UNKNOWN is 0ULL - 1
+ ZSTD_CONTENTSIZE_ERROR is 0ULL - 2
+ Users should only pass values < ZSTD_CONTENTSIZE_ERROR */
+ unsigned long long pledged_size = PyLong_AsUnsignedLongLong(size);
+ /* Here we check for (unsigned long long)-1 as a sign of an error in
+ PyLong_AsUnsignedLongLong */
+ if (pledged_size == (unsigned long long)-1 && PyErr_Occurred()) {
+ *p = ZSTD_CONTENTSIZE_ERROR;
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ PyErr_Format(PyExc_ValueError,
+ "size argument should be a positive int less "
+ "than %ull", ZSTD_CONTENTSIZE_ERROR);
+ return 0;
+ }
+ return 0;
+ }
+ if (pledged_size >= ZSTD_CONTENTSIZE_ERROR) {
+ *p = ZSTD_CONTENTSIZE_ERROR;
+ PyErr_Format(PyExc_ValueError,
+ "size argument should be a positive int less "
+ "than %ull", ZSTD_CONTENTSIZE_ERROR);
+ return 0;
+ }
+ *p = pledged_size;
+ }
+ return 1;
+}
+
#include "clinic/compressor.c.h"
static int
-_zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options,
- const char *arg_name, const char* arg_type)
+_zstd_set_c_level(ZstdCompressor *self, int level)
{
- size_t zstd_ret;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state == NULL) {
+ /* Set integer compression level */
+ int min_level = ZSTD_minCLevel();
+ int max_level = ZSTD_maxCLevel();
+ if (level < min_level || level > max_level) {
+ PyErr_Format(PyExc_ValueError,
+ "illegal compression level %d; the valid range is [%d, %d]",
+ level, min_level, max_level);
return -1;
}
- /* Integer compression level */
- if (PyLong_Check(level_or_options)) {
- int level = PyLong_AsInt(level_or_options);
- if (level == -1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Compression level should be an int value between %d and %d.",
- ZSTD_minCLevel(), ZSTD_maxCLevel());
- return -1;
- }
+ /* Save for generating ZSTD_CDICT */
+ self->compression_level = level;
- /* Save for generating ZSTD_CDICT */
- self->compression_level = level;
+ /* Set compressionLevel to compression context */
+ size_t zstd_ret = ZSTD_CCtx_setParameter(
+ self->cctx, ZSTD_c_compressionLevel, level);
- /* Set compressionLevel to compression context */
- zstd_ret = ZSTD_CCtx_setParameter(self->cctx,
- ZSTD_c_compressionLevel,
- level);
+ /* Check error */
+ if (ZSTD_isError(zstd_ret)) {
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
+ return -1;
+ }
+ return 0;
+}
- /* Check error */
- if (ZSTD_isError(zstd_ret)) {
- set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret);
+static int
+_zstd_set_c_parameters(ZstdCompressor *self, PyObject *options)
+{
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ if (mod_state == NULL) {
+ return -1;
+ }
+
+ if (!PyDict_Check(options)) {
+ PyErr_Format(PyExc_TypeError,
+ "ZstdCompressor() argument 'options' must be dict, not %T",
+ options);
+ return -1;
+ }
+
+ Py_ssize_t pos = 0;
+ PyObject *key, *value;
+ while (PyDict_Next(options, &pos, &key, &value)) {
+ /* Check key type */
+ if (Py_TYPE(key) == mod_state->DParameter_type) {
+ PyErr_SetString(PyExc_TypeError,
+ "compression options dictionary key must not be a "
+ "DecompressionParameter attribute");
return -1;
}
- return 0;
- }
- /* Options dict */
- if (PyDict_Check(level_or_options)) {
- PyObject *key, *value;
- Py_ssize_t pos = 0;
+ Py_INCREF(key);
+ Py_INCREF(value);
+ int key_v = PyLong_AsInt(key);
+ Py_DECREF(key);
+ if (key_v == -1 && PyErr_Occurred()) {
+ Py_DECREF(value);
+ return -1;
+ }
- while (PyDict_Next(level_or_options, &pos, &key, &value)) {
- /* Check key type */
- if (Py_TYPE(key) == mod_state->DParameter_type) {
- PyErr_SetString(PyExc_TypeError,
- "Key of compression option dict should "
- "NOT be DecompressionParameter.");
- return -1;
- }
+ int value_v = PyLong_AsInt(value);
+ Py_DECREF(value);
+ if (value_v == -1 && PyErr_Occurred()) {
+ return -1;
+ }
- int key_v = PyLong_AsInt(key);
- if (key_v == -1 && PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "Key of options dict should be a CompressionParameter attribute.");
+ if (key_v == ZSTD_c_compressionLevel) {
+ if (_zstd_set_c_level(self, value_v) < 0) {
return -1;
}
-
- // TODO(emmatyping): check bounds when there is a value error here for better
- // error message?
- int value_v = PyLong_AsInt(value);
- if (value_v == -1 && PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "Value of option dict should be an int.");
- return -1;
+ continue;
+ }
+ if (key_v == ZSTD_c_nbWorkers) {
+ /* From the zstd library docs:
+ 1. When nbWorkers >= 1, triggers asynchronous mode when
+ used with ZSTD_compressStream2().
+ 2, Default value is `0`, aka "single-threaded mode" : no
+ worker is spawned, compression is performed inside
+ caller's thread, all invocations are blocking. */
+ if (value_v != 0) {
+ self->use_multithread = 1;
}
+ }
- if (key_v == ZSTD_c_compressionLevel) {
- /* Save for generating ZSTD_CDICT */
- self->compression_level = value_v;
- }
- else if (key_v == ZSTD_c_nbWorkers) {
- /* From the zstd library docs:
- 1. When nbWorkers >= 1, triggers asynchronous mode when
- used with ZSTD_compressStream2().
- 2, Default value is `0`, aka "single-threaded mode" : no
- worker is spawned, compression is performed inside
- caller's thread, all invocations are blocking. */
- if (value_v != 0) {
- self->use_multithread = 1;
- }
- }
+ /* Set parameter to compression context */
+ size_t zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
- /* Set parameter to compression context */
- zstd_ret = ZSTD_CCtx_setParameter(self->cctx, key_v, value_v);
- if (ZSTD_isError(zstd_ret)) {
- set_parameter_error(mod_state, 1, key_v, value_v);
- return -1;
- }
+ /* Check error */
+ if (ZSTD_isError(zstd_ret)) {
+ set_parameter_error(1, key_v, value_v);
+ return -1;
}
- return 0;
}
- PyErr_Format(PyExc_TypeError, "Invalid type for %s. Expected %s", arg_name, arg_type);
- return -1;
+ return 0;
}
static void
@@ -149,12 +203,12 @@ capsule_free_cdict(PyObject *capsule)
ZSTD_CDict *
_get_CDict(ZstdDict *self, int compressionLevel)
{
+ assert(PyMutex_IsLocked(&self->lock));
PyObject *level = NULL;
- PyObject *capsule;
+ PyObject *capsule = NULL;
ZSTD_CDict *cdict;
+ int ret;
- // TODO(emmatyping): refactor critical section code into a lock_held function
- Py_BEGIN_CRITICAL_SECTION(self);
/* int level object */
level = PyLong_FromLong(compressionLevel);
@@ -163,23 +217,19 @@ _get_CDict(ZstdDict *self, int compressionLevel)
}
/* Get PyCapsule object from self->c_dicts */
- capsule = PyDict_GetItemWithError(self->c_dicts, level);
+ ret = PyDict_GetItemRef(self->c_dicts, level, &capsule);
+ if (ret < 0) {
+ goto error;
+ }
if (capsule == NULL) {
- if (PyErr_Occurred()) {
- goto error;
- }
-
/* Create ZSTD_CDict instance */
- char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
- Py_ssize_t dict_len = Py_SIZE(self->dict_content);
Py_BEGIN_ALLOW_THREADS
- cdict = ZSTD_createCDict(dict_buffer,
- dict_len,
+ cdict = ZSTD_createCDict(self->dict_buffer, self->dict_len,
compressionLevel);
Py_END_ALLOW_THREADS
if (cdict == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Failed to create a ZSTD_CDict instance from "
@@ -196,11 +246,10 @@ _get_CDict(ZstdDict *self, int compressionLevel)
}
/* Add PyCapsule object to self->c_dicts */
- if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) {
- Py_DECREF(capsule);
+ ret = PyDict_SetItem(self->c_dicts, level, capsule);
+ if (ret < 0) {
goto error;
}
- Py_DECREF(capsule);
}
else {
/* ZSTD_CDict instance already exists */
@@ -212,62 +261,15 @@ error:
cdict = NULL;
success:
Py_XDECREF(level);
- Py_END_CRITICAL_SECTION();
+ Py_XDECREF(capsule);
return cdict;
}
static int
-_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd,
+ _zstd_state *mod_state, int type)
{
-
size_t zstd_ret;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state == NULL) {
- return -1;
- }
- ZstdDict *zd;
- int type, ret;
-
- /* Check ZstdDict */
- ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
- if (ret < 0) {
- return -1;
- }
- else if (ret > 0) {
- /* When compressing, use undigested dictionary by default. */
- zd = (ZstdDict*)dict;
- type = DICT_TYPE_UNDIGESTED;
- goto load;
- }
-
- /* Check (ZstdDict, type) */
- if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
- /* Check ZstdDict */
- ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
- (PyObject*)mod_state->ZstdDict_type);
- if (ret < 0) {
- return -1;
- }
- else if (ret > 0) {
- /* type == -1 may indicate an error. */
- type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
- if (type == DICT_TYPE_DIGESTED ||
- type == DICT_TYPE_UNDIGESTED ||
- type == DICT_TYPE_PREFIX)
- {
- assert(type >= 0);
- zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
- goto load;
- }
- }
- }
-
- /* Wrong type */
- PyErr_SetString(PyExc_TypeError,
- "zstd_dict argument should be ZstdDict object.");
- return -1;
-
-load:
if (type == DICT_TYPE_DIGESTED) {
/* Get ZSTD_CDict */
ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level);
@@ -276,28 +278,18 @@ load:
}
/* Reference a prepared dictionary.
It overrides some compression context's parameters. */
- Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict);
- Py_END_CRITICAL_SECTION();
}
else if (type == DICT_TYPE_UNDIGESTED) {
/* Load a dictionary.
It doesn't override compression context's parameters. */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_CCtx_loadDictionary(
- self->cctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
+ zstd_ret = ZSTD_CCtx_loadDictionary(self->cctx, zd->dict_buffer,
+ zd->dict_len);
}
else if (type == DICT_TYPE_PREFIX) {
/* Load a prefix */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_CCtx_refPrefix(
- self->cctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
+ zstd_ret = ZSTD_CCtx_refPrefix(self->cctx, zd->dict_buffer,
+ zd->dict_len);
}
else {
Py_UNREACHABLE();
@@ -311,6 +303,23 @@ load:
return 0;
}
+static int
+_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
+{
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ /* When compressing, use undigested dictionary by default. */
+ int type = DICT_TYPE_UNDIGESTED;
+ ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type);
+ if (zd == NULL) {
+ return -1;
+ }
+ int ret;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
+}
+
/*[clinic input]
@classmethod
_zstd.ZstdCompressor.__new__ as _zstd_ZstdCompressor_new
@@ -339,11 +348,12 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
self->use_multithread = 0;
self->dict = NULL;
+ self->lock = (PyMutex){0};
/* Compression context */
self->cctx = ZSTD_createCCtx();
if (self->cctx == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Unable to create ZSTD_CCtx instance.");
@@ -355,19 +365,35 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level,
self->last_mode = ZSTD_e_end;
if (level != Py_None && options != Py_None) {
- PyErr_SetString(PyExc_RuntimeError, "Only one of level or options should be used.");
+ PyErr_SetString(PyExc_TypeError,
+ "Only one of level or options should be used.");
goto error;
}
- /* Set compressLevel/options to compression context */
+ /* Set compression level */
if (level != Py_None) {
- if (_zstd_set_c_parameters(self, level, "level", "int") < 0) {
+ if (!PyLong_Check(level)) {
+ PyErr_SetString(PyExc_TypeError,
+ "invalid type for level, expected int");
+ goto error;
+ }
+ int level_v = PyLong_AsInt(level);
+ if (level_v == -1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ PyErr_Format(PyExc_ValueError,
+ "illegal compression level; the valid range is [%d, %d]",
+ ZSTD_minCLevel(), ZSTD_maxCLevel());
+ }
+ goto error;
+ }
+ if (_zstd_set_c_level(self, level_v) < 0) {
goto error;
}
}
+ /* Set options dictionary */
if (options != Py_None) {
- if (_zstd_set_c_parameters(self, options, "options", "dict") < 0) {
+ if (_zstd_set_c_parameters(self, options) < 0) {
goto error;
}
}
@@ -403,6 +429,8 @@ ZstdCompressor_dealloc(PyObject *ob)
ZSTD_freeCCtx(self->cctx);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Py_XDECREF the dict after free the compression context */
Py_CLEAR(self->dict);
@@ -412,9 +440,10 @@ ZstdCompressor_dealloc(PyObject *ob)
}
static PyObject *
-compress_impl(ZstdCompressor *self, Py_buffer *data,
- ZSTD_EndDirective end_directive)
+compress_lock_held(ZstdCompressor *self, Py_buffer *data,
+ ZSTD_EndDirective end_directive)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
ZSTD_outBuffer out;
_BlocksOutputBuffer buffer = {.list = NULL};
@@ -441,7 +470,7 @@ compress_impl(ZstdCompressor *self, Py_buffer *data,
}
if (_OutputBuffer_InitWithSize(&buffer, &out, -1,
- (Py_ssize_t) output_buffer_size) < 0) {
+ (Py_ssize_t) output_buffer_size) < 0) {
goto error;
}
@@ -454,10 +483,8 @@ compress_impl(ZstdCompressor *self, Py_buffer *data,
/* Check error */
if (ZSTD_isError(zstd_ret)) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state != NULL) {
- set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
- }
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
goto error;
}
@@ -486,7 +513,7 @@ error:
return NULL;
}
-#ifdef Py_DEBUG
+#ifndef NDEBUG
static inline int
mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out)
{
@@ -495,8 +522,9 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out)
#endif
static PyObject *
-compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
+compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
ZSTD_outBuffer out;
_BlocksOutputBuffer buffer = {.list = NULL};
@@ -516,20 +544,21 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data)
while (1) {
Py_BEGIN_ALLOW_THREADS
do {
- zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in, ZSTD_e_continue);
- } while (out.pos != out.size && in.pos != in.size && !ZSTD_isError(zstd_ret));
+ zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in,
+ ZSTD_e_continue);
+ } while (out.pos != out.size
+ && in.pos != in.size
+ && !ZSTD_isError(zstd_ret));
Py_END_ALLOW_THREADS
/* Check error */
if (ZSTD_isError(zstd_ret)) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state != NULL) {
- set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
- }
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret);
goto error;
}
- /* Like compress_impl(), output as much as possible. */
+ /* Like compress_lock_held(), output as much as possible. */
if (out.pos == out.size) {
if (_OutputBuffer_Grow(&buffer, &out) < 0) {
goto error;
@@ -588,14 +617,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
}
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
+ PyMutex_Lock(&self->lock);
/* Compress */
if (self->use_multithread && mode == ZSTD_e_continue) {
- ret = compress_mt_continue_impl(self, data);
+ ret = compress_mt_continue_lock_held(self, data);
}
else {
- ret = compress_impl(self, data, mode);
+ ret = compress_lock_held(self, data, mode);
}
if (ret) {
@@ -607,7 +636,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data,
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
- Py_END_CRITICAL_SECTION();
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -642,8 +671,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
}
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
- ret = compress_impl(self, NULL, mode);
+ PyMutex_Lock(&self->lock);
+
+ ret = compress_lock_held(self, NULL, mode);
if (ret) {
self->last_mode = mode;
@@ -654,26 +684,78 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode)
/* Resetting cctx's session never fail */
ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only);
}
- Py_END_CRITICAL_SECTION();
+ PyMutex_Unlock(&self->lock);
return ret;
}
+
+/*[clinic input]
+_zstd.ZstdCompressor.set_pledged_input_size
+
+ size: zstd_contentsize
+ The size of the uncompressed data to be provided to the compressor.
+ /
+
+Set the uncompressed content size to be written into the frame header.
+
+This method can be used to ensure the header of the frame about to be written
+includes the size of the data, unless the CompressionParameter.content_size_flag
+is set to False. If last_mode != FLUSH_FRAME, then a RuntimeError is raised.
+
+It is important to ensure that the pledged data size matches the actual data
+size. If they do not match the compressed output data may be corrupted and the
+final chunk written may be lost.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdCompressor_set_pledged_input_size_impl(ZstdCompressor *self,
+ unsigned long long size)
+/*[clinic end generated code: output=3a09e55cc0e3b4f9 input=afd8a7d78cff2eb5]*/
+{
+ // Error occured while converting argument, should be unreachable
+ assert(size != ZSTD_CONTENTSIZE_ERROR);
+
+ /* Thread-safe code */
+ PyMutex_Lock(&self->lock);
+
+ /* Check the current mode */
+ if (self->last_mode != ZSTD_e_end) {
+ PyErr_SetString(PyExc_ValueError,
+ "set_pledged_input_size() method must be called "
+ "when last_mode == FLUSH_FRAME");
+ PyMutex_Unlock(&self->lock);
+ return NULL;
+ }
+
+ /* Set pledged content size */
+ size_t zstd_ret = ZSTD_CCtx_setPledgedSrcSize(self->cctx, size);
+ PyMutex_Unlock(&self->lock);
+ if (ZSTD_isError(zstd_ret)) {
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ set_zstd_error(mod_state, ERR_SET_PLEDGED_INPUT_SIZE, zstd_ret);
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
+}
+
static PyMethodDef ZstdCompressor_methods[] = {
_ZSTD_ZSTDCOMPRESSOR_COMPRESS_METHODDEF
_ZSTD_ZSTDCOMPRESSOR_FLUSH_METHODDEF
+ _ZSTD_ZSTDCOMPRESSOR_SET_PLEDGED_INPUT_SIZE_METHODDEF
{NULL, NULL}
};
PyDoc_STRVAR(ZstdCompressor_last_mode_doc,
"The last mode used to this compressor object, its value can be .CONTINUE,\n"
".FLUSH_BLOCK, .FLUSH_FRAME. Initialized to .FLUSH_FRAME.\n\n"
-"It can be used to get the current state of a compressor, such as, data flushed,\n"
-"a frame ended.");
+"It can be used to get the current state of a compressor, such as, data\n"
+"flushed, or a frame ended.");
static PyMemberDef ZstdCompressor_members[] = {
{"last_mode", Py_T_INT, offsetof(ZstdCompressor, last_mode),
- Py_READONLY, ZstdCompressor_last_mode_doc},
+ Py_READONLY, ZstdCompressor_last_mode_doc},
{NULL}
};
diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c
index 58f9c9f804e..c53d6e4cb05 100644
--- a/Modules/_zstd/decompressor.c
+++ b/Modules/_zstd/decompressor.c
@@ -16,7 +16,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
-#include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <stdbool.h> // bool
#include <stddef.h> // offsetof()
@@ -45,6 +45,9 @@ typedef struct {
/* For ZstdDecompressor, 0 or 1.
1 means the end of the first frame has been reached. */
bool eof;
+
+ /* Lock to protect the decompression context */
+ PyMutex lock;
} ZstdDecompressor;
#define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op)
@@ -54,25 +57,18 @@ typedef struct {
static inline ZSTD_DDict *
_get_DDict(ZstdDict *self)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_DDict *ret;
- /* Already created */
- if (self->d_dict != NULL) {
- return self->d_dict;
- }
-
- Py_BEGIN_CRITICAL_SECTION(self);
if (self->d_dict == NULL) {
/* Create ZSTD_DDict instance from dictionary content */
- char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
- Py_ssize_t dict_len = Py_SIZE(self->dict_content);
Py_BEGIN_ALLOW_THREADS
- self->d_dict = ZSTD_createDDict(dict_buffer,
- dict_len);
+ ret = ZSTD_createDDict(self->dict_buffer, self->dict_len);
Py_END_ALLOW_THREADS
+ self->d_dict = ret;
if (self->d_dict == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Failed to create a ZSTD_DDict instance from "
@@ -81,124 +77,66 @@ _get_DDict(ZstdDict *self)
}
}
- /* Don't lose any exception */
- ret = self->d_dict;
- Py_END_CRITICAL_SECTION();
-
- return ret;
+ return self->d_dict;
}
-/* Set decompression parameters to decompression context */
static int
_zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
{
- size_t zstd_ret;
- PyObject *key, *value;
- Py_ssize_t pos;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}
if (!PyDict_Check(options)) {
- PyErr_SetString(PyExc_TypeError,
- "options argument should be dict object.");
+ PyErr_Format(PyExc_TypeError,
+ "ZstdDecompressor() argument 'options' must be dict, not %T",
+ options);
return -1;
}
- pos = 0;
+ Py_ssize_t pos = 0;
+ PyObject *key, *value;
while (PyDict_Next(options, &pos, &key, &value)) {
/* Check key type */
if (Py_TYPE(key) == mod_state->CParameter_type) {
PyErr_SetString(PyExc_TypeError,
- "Key of decompression options dict should "
- "NOT be CompressionParameter.");
+ "compression options dictionary key must not be a "
+ "CompressionParameter attribute");
return -1;
}
- /* Both key & value should be 32-bit signed int */
+ Py_INCREF(key);
+ Py_INCREF(value);
int key_v = PyLong_AsInt(key);
+ Py_DECREF(key);
if (key_v == -1 && PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "Key of options dict should be a DecompressionParameter attribute.");
return -1;
}
- // TODO(emmatyping): check bounds when there is a value error here for better
- // error message?
int value_v = PyLong_AsInt(value);
+ Py_DECREF(value);
if (value_v == -1 && PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "Value of options dict should be an int.");
return -1;
}
/* Set parameter to compression context */
- Py_BEGIN_CRITICAL_SECTION(self);
- zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
- Py_END_CRITICAL_SECTION();
+ size_t zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
/* Check error */
if (ZSTD_isError(zstd_ret)) {
- set_parameter_error(mod_state, 0, key_v, value_v);
+ set_parameter_error(0, key_v, value_v);
return -1;
}
}
return 0;
}
-/* Load dictionary or prefix to decompression context */
static int
-_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
+_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd,
+ _zstd_state *mod_state, int type)
{
size_t zstd_ret;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state == NULL) {
- return -1;
- }
- ZstdDict *zd;
- int type, ret;
-
- /* Check ZstdDict */
- ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type);
- if (ret < 0) {
- return -1;
- }
- else if (ret > 0) {
- /* When decompressing, use digested dictionary by default. */
- zd = (ZstdDict*)dict;
- type = DICT_TYPE_DIGESTED;
- goto load;
- }
-
- /* Check (ZstdDict, type) */
- if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
- /* Check ZstdDict */
- ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0),
- (PyObject*)mod_state->ZstdDict_type);
- if (ret < 0) {
- return -1;
- }
- else if (ret > 0) {
- /* type == -1 may indicate an error. */
- type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
- if (type == DICT_TYPE_DIGESTED ||
- type == DICT_TYPE_UNDIGESTED ||
- type == DICT_TYPE_PREFIX)
- {
- assert(type >= 0);
- zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
- goto load;
- }
- }
- }
-
- /* Wrong type */
- PyErr_SetString(PyExc_TypeError,
- "zstd_dict argument should be ZstdDict object.");
- return -1;
-
-load:
if (type == DICT_TYPE_DIGESTED) {
/* Get ZSTD_DDict */
ZSTD_DDict *d_dict = _get_DDict(zd);
@@ -206,27 +144,17 @@ load:
return -1;
}
/* Reference a prepared dictionary */
- Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
- Py_END_CRITICAL_SECTION();
}
else if (type == DICT_TYPE_UNDIGESTED) {
/* Load a dictionary */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_loadDictionary(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
+ zstd_ret = ZSTD_DCtx_loadDictionary(self->dctx, zd->dict_buffer,
+ zd->dict_len);
}
else if (type == DICT_TYPE_PREFIX) {
/* Load a prefix */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_refPrefix(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
+ zstd_ret = ZSTD_DCtx_refPrefix(self->dctx, zd->dict_buffer,
+ zd->dict_len);
}
else {
/* Impossible code path */
@@ -243,6 +171,24 @@ load:
return 0;
}
+/* Load dictionary or prefix to decompression context */
+static int
+_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
+{
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ /* When decompressing, use digested dictionary by default. */
+ int type = DICT_TYPE_DIGESTED;
+ ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type);
+ if (zd == NULL) {
+ return -1;
+ }
+ int ret;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
+}
+
/*
Decompress implementation in pseudo code:
@@ -268,8 +214,8 @@ load:
Note, decompressing "an empty input" in any case will make it > 0.
*/
static PyObject *
-decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
- Py_ssize_t max_length)
+decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
+ Py_ssize_t max_length)
{
size_t zstd_ret;
ZSTD_outBuffer out;
@@ -290,10 +236,8 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
/* Check error */
if (ZSTD_isError(zstd_ret)) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
- if (mod_state != NULL) {
- set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
- }
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
+ set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
goto error;
}
@@ -339,10 +283,9 @@ error:
}
static void
-decompressor_reset_session(ZstdDecompressor *self)
+decompressor_reset_session_lock_held(ZstdDecompressor *self)
{
- // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
- // and ensure lock is always held
+ assert(PyMutex_IsLocked(&self->lock));
/* Reset variables */
self->in_begin = 0;
@@ -359,15 +302,18 @@ decompressor_reset_session(ZstdDecompressor *self)
}
static PyObject *
-stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length)
+stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data,
+ Py_ssize_t max_length)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
PyObject *ret = NULL;
int use_input_buffer;
/* Check .eof flag */
if (self->eof) {
- PyErr_SetString(PyExc_EOFError, "Already at the end of a Zstandard frame.");
+ PyErr_SetString(PyExc_EOFError,
+ "Already at the end of a Zstandard frame.");
assert(ret == NULL);
return NULL;
}
@@ -456,7 +402,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
assert(in.pos == 0);
/* Decompress */
- ret = decompress_impl(self, &in, max_length);
+ ret = decompress_lock_held(self, &in, max_length);
if (ret == NULL) {
goto error;
}
@@ -484,8 +430,8 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
if (!use_input_buffer) {
/* Discard buffer if it's too small
(resizing it may needlessly copy the current contents) */
- if (self->input_buffer != NULL &&
- self->input_buffer_size < data_size)
+ if (self->input_buffer != NULL
+ && self->input_buffer_size < data_size)
{
PyMem_Free(self->input_buffer);
self->input_buffer = NULL;
@@ -517,7 +463,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
error:
/* Reset decompressor's states/session */
- decompressor_reset_session(self);
+ decompressor_reset_session_lock_held(self);
Py_CLEAR(ret);
return NULL;
@@ -555,6 +501,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
self->unused_data = NULL;
self->eof = 0;
self->dict = NULL;
+ self->lock = (PyMutex){0};
/* needs_input flag */
self->needs_input = 1;
@@ -562,7 +509,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
/* Decompression context */
self->dctx = ZSTD_createDCtx();
if (self->dctx == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Unable to create ZSTD_DCtx instance.");
@@ -579,7 +526,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
self->dict = zstd_dict;
}
- /* Set option to decompression context */
+ /* Set options dictionary */
if (options != Py_None) {
if (_zstd_set_d_parameters(self, options) < 0) {
goto error;
@@ -608,6 +555,8 @@ ZstdDecompressor_dealloc(PyObject *ob)
ZSTD_freeDCtx(self->dctx);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Py_CLEAR the dict after free decompression context */
Py_CLEAR(self->dict);
@@ -623,7 +572,6 @@ ZstdDecompressor_dealloc(PyObject *ob)
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDecompressor.unused_data
@@ -635,11 +583,14 @@ decompressed, unused input data after the frame. Otherwise this will be b''.
static PyObject *
_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
-/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/
+/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/
{
PyObject *ret;
+ PyMutex_Lock(&self->lock);
+
if (!self->eof) {
+ PyMutex_Unlock(&self->lock);
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
}
else {
@@ -656,6 +607,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
}
}
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -693,10 +645,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
{
PyObject *ret;
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
-
- ret = stream_decompress(self, data, max_length);
- Py_END_CRITICAL_SECTION();
+ PyMutex_Lock(&self->lock);
+ ret = stream_decompress_lock_held(self, data, max_length);
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -710,9 +661,10 @@ PyDoc_STRVAR(ZstdDecompressor_eof_doc,
"after that, an EOFError exception will be raised.");
PyDoc_STRVAR(ZstdDecompressor_needs_input_doc,
-"If the max_length output limit in .decompress() method has been reached, and\n"
-"the decompressor has (or may has) unconsumed input data, it will be set to\n"
-"False. In this case, pass b'' to .decompress() method may output further data.");
+"If the max_length output limit in .decompress() method has been reached,\n"
+"and the decompressor has (or may has) unconsumed input data, it will be set\n"
+"to False. In this case, passing b'' to the .decompress() method may output\n"
+"further data.");
static PyMemberDef ZstdDecompressor_members[] = {
{"eof", Py_T_BOOL, offsetof(ZstdDecompressor, eof),
diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c
index 7df187a6fa6..14f74aaed46 100644
--- a/Modules/_zstd/zstddict.c
+++ b/Modules/_zstd/zstddict.c
@@ -15,8 +15,8 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec"
#include "Python.h"
#include "_zstdmodule.h"
-#include "zstddict.h"
#include "clinic/zstddict.c.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <zstd.h> // ZSTD_freeDDict(), ZSTD_getDictID_fromDict()
@@ -25,7 +25,7 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec"
/*[clinic input]
@classmethod
_zstd.ZstdDict.__new__ as _zstd_ZstdDict_new
- dict_content: object
+ dict_content: Py_buffer
The content of a Zstandard dictionary as a bytes-like object.
/
*
@@ -41,18 +41,27 @@ by multiple ZstdCompressor or ZstdDecompressor objects.
[clinic start generated code]*/
static PyObject *
-_zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content,
+_zstd_ZstdDict_new_impl(PyTypeObject *type, Py_buffer *dict_content,
int is_raw)
-/*[clinic end generated code: output=3ebff839cb3be6d7 input=6b5de413869ae878]*/
+/*[clinic end generated code: output=685b7406a48b0949 input=9e8c493e31c98383]*/
{
+ /* All dictionaries must be at least 8 bytes */
+ if (dict_content->len < 8) {
+ PyErr_SetString(PyExc_ValueError,
+ "Zstandard dictionary content too short "
+ "(must have at least eight bytes)");
+ return NULL;
+ }
+
ZstdDict* self = PyObject_GC_New(ZstdDict, type);
if (self == NULL) {
- goto error;
+ return NULL;
}
- self->dict_content = NULL;
self->d_dict = NULL;
+ self->dict_buffer = NULL;
self->dict_id = 0;
+ self->lock = (PyMutex){0};
/* ZSTD_CDict dict */
self->c_dicts = PyDict_New();
@@ -60,37 +69,26 @@ _zstd_ZstdDict_new_impl(PyTypeObject *type, PyObject *dict_content,
goto error;
}
- /* Check dict_content's type */
- self->dict_content = PyBytes_FromObject(dict_content);
- if (self->dict_content == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "dict_content argument should be bytes-like object.");
- goto error;
- }
-
- /* Both ordinary dictionary and "raw content" dictionary should
- at least 8 bytes */
- if (Py_SIZE(self->dict_content) < 8) {
- PyErr_SetString(PyExc_ValueError,
- "Zstandard dictionary content should at least 8 bytes.");
+ self->dict_buffer = PyMem_Malloc(dict_content->len);
+ if (!self->dict_buffer) {
+ PyErr_NoMemory();
goto error;
}
+ memcpy(self->dict_buffer, dict_content->buf, dict_content->len);
+ self->dict_len = dict_content->len;
/* Get dict_id, 0 means "raw content" dictionary. */
- self->dict_id = ZSTD_getDictID_fromDict(PyBytes_AS_STRING(self->dict_content),
- Py_SIZE(self->dict_content));
+ self->dict_id = ZSTD_getDictID_fromDict(self->dict_buffer, self->dict_len);
/* Check validity for ordinary dictionary */
if (!is_raw && self->dict_id == 0) {
- char *msg = "Invalid Zstandard dictionary and is_raw not set.\n";
- PyErr_SetString(PyExc_ValueError, msg);
+ PyErr_SetString(PyExc_ValueError, "invalid Zstandard dictionary");
goto error;
}
- // Can only track self once self->dict_content is included
PyObject_GC_Track(self);
- return (PyObject*)self;
+ return (PyObject *)self;
error:
Py_XDECREF(self);
@@ -109,12 +107,14 @@ ZstdDict_dealloc(PyObject *ob)
ZSTD_freeDDict(self->d_dict);
}
- /* Release dict_content after Free ZSTD_CDict/ZSTD_DDict instances */
- Py_CLEAR(self->dict_content);
+ assert(!PyMutex_IsLocked(&self->lock));
+
+ /* Release dict_buffer after freeing ZSTD_CDict/ZSTD_DDict instances */
+ PyMem_Free(self->dict_buffer);
Py_CLEAR(self->c_dicts);
PyTypeObject *tp = Py_TYPE(self);
- PyObject_GC_Del(ob);
+ tp->tp_free(self);
Py_DECREF(tp);
}
@@ -125,31 +125,42 @@ PyDoc_STRVAR(ZstdDict_dictid_doc,
"The special value '0' means a 'raw content' dictionary,"
"without any restrictions on format or content.");
-PyDoc_STRVAR(ZstdDict_dictcontent_doc,
-"The content of a Zstandard dictionary, as a bytes object.");
-
static PyObject *
-ZstdDict_str(PyObject *ob)
+ZstdDict_repr(PyObject *ob)
{
ZstdDict *dict = ZstdDict_CAST(ob);
return PyUnicode_FromFormat("<ZstdDict dict_id=%u dict_size=%zd>",
- dict->dict_id, Py_SIZE(dict->dict_content));
+ (unsigned int)dict->dict_id, dict->dict_len);
}
static PyMemberDef ZstdDict_members[] = {
{"dict_id", Py_T_UINT, offsetof(ZstdDict, dict_id), Py_READONLY, ZstdDict_dictid_doc},
- {"dict_content", Py_T_OBJECT_EX, offsetof(ZstdDict, dict_content), Py_READONLY, ZstdDict_dictcontent_doc},
{NULL}
};
/*[clinic input]
-@critical_section
+@getter
+_zstd.ZstdDict.dict_content
+
+The content of a Zstandard dictionary, as a bytes object.
+[clinic start generated code]*/
+
+static PyObject *
+_zstd_ZstdDict_dict_content_get_impl(ZstdDict *self)
+/*[clinic end generated code: output=0d05caa5b550eabb input=4ed526d1c151c596]*/
+{
+ return PyBytes_FromStringAndSize(self->dict_buffer, self->dict_len);
+}
+
+/*[clinic input]
@getter
_zstd.ZstdDict.as_digested_dict
Load as a digested dictionary to compressor.
-Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digested_dict)
+Pass this attribute as zstd_dict argument:
+compress(dat, zstd_dict=zd.as_digested_dict)
+
1. Some advanced compression parameters of compressor may be overridden
by parameters of digested dictionary.
2. ZstdDict has a digested dictionaries cache for each compression level.
@@ -160,19 +171,20 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_digeste
static PyObject *
_zstd_ZstdDict_as_digested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=09b086e7a7320dbb input=585448c79f31f74a]*/
+/*[clinic end generated code: output=09b086e7a7320dbb input=ee45e1b4a48f6f2c]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_DIGESTED);
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDict.as_undigested_dict
Load as an undigested dictionary to compressor.
-Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undigested_dict)
+Pass this attribute as zstd_dict argument:
+compress(dat, zstd_dict=zd.as_undigested_dict)
+
1. The advanced compression parameters of compressor will not be overridden.
2. Loading an undigested dictionary is costly. If load an undigested dictionary
multiple times, consider reusing a compressor object.
@@ -181,19 +193,20 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_undiges
static PyObject *
_zstd_ZstdDict_as_undigested_dict_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=43c7a989e6d4253a input=022b0829ffb1c220]*/
+/*[clinic end generated code: output=43c7a989e6d4253a input=d39210eedec76fed]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_UNDIGESTED);
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDict.as_prefix
Load as a prefix to compressor/decompressor.
-Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)
+Pass this attribute as zstd_dict argument:
+compress(dat, zstd_dict=zd.as_prefix)
+
1. Prefix is compatible with long distance matching, while dictionary is not.
2. It only works for the first frame, then the compressor/decompressor will
return to no prefix state.
@@ -202,12 +215,13 @@ Pass this attribute as zstd_dict argument: compress(dat, zstd_dict=zd.as_prefix)
static PyObject *
_zstd_ZstdDict_as_prefix_get_impl(ZstdDict *self)
-/*[clinic end generated code: output=6f7130c356595a16 input=09fb82a6a5407e87]*/
+/*[clinic end generated code: output=6f7130c356595a16 input=d59757b0b5a9551a]*/
{
return Py_BuildValue("Oi", self, DICT_TYPE_PREFIX);
}
static PyGetSetDef ZstdDict_getset[] = {
+ _ZSTD_ZSTDDICT_DICT_CONTENT_GETSETDEF
_ZSTD_ZSTDDICT_AS_DIGESTED_DICT_GETSETDEF
_ZSTD_ZSTDDICT_AS_UNDIGESTED_DICT_GETSETDEF
_ZSTD_ZSTDDICT_AS_PREFIX_GETSETDEF
@@ -218,8 +232,7 @@ static Py_ssize_t
ZstdDict_length(PyObject *ob)
{
ZstdDict *self = ZstdDict_CAST(ob);
- assert(PyBytes_Check(self->dict_content));
- return Py_SIZE(self->dict_content);
+ return self->dict_len;
}
static int
@@ -227,7 +240,6 @@ ZstdDict_traverse(PyObject *ob, visitproc visit, void *arg)
{
ZstdDict *self = ZstdDict_CAST(ob);
Py_VISIT(self->c_dicts);
- Py_VISIT(self->dict_content);
return 0;
}
@@ -235,7 +247,7 @@ static int
ZstdDict_clear(PyObject *ob)
{
ZstdDict *self = ZstdDict_CAST(ob);
- Py_CLEAR(self->dict_content);
+ Py_CLEAR(self->c_dicts);
return 0;
}
@@ -244,7 +256,7 @@ static PyType_Slot zstddict_slots[] = {
{Py_tp_getset, ZstdDict_getset},
{Py_tp_new, _zstd_ZstdDict_new},
{Py_tp_dealloc, ZstdDict_dealloc},
- {Py_tp_str, ZstdDict_str},
+ {Py_tp_repr, ZstdDict_repr},
{Py_tp_doc, (void *)_zstd_ZstdDict_new__doc__},
{Py_sq_length, ZstdDict_length},
{Py_tp_traverse, ZstdDict_traverse},
diff --git a/Modules/_zstd/zstddict.h b/Modules/_zstd/zstddict.h
index e8a55a3670b..4a403416dbd 100644
--- a/Modules/_zstd/zstddict.h
+++ b/Modules/_zstd/zstddict.h
@@ -15,10 +15,15 @@ typedef struct {
ZSTD_DDict *d_dict;
PyObject *c_dicts;
- /* Content of the dictionary, bytes object. */
- PyObject *dict_content;
+ /* Dictionary content. */
+ char *dict_buffer;
+ Py_ssize_t dict_len;
+
/* Dictionary id */
uint32_t dict_id;
+
+ /* Lock to protect the digested dictionaries */
+ PyMutex lock;
} ZstdDict;
#endif // !ZSTD_DICT_H