aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Modules/_zstd/decompressor.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_zstd/decompressor.c')
-rw-r--r--Modules/_zstd/decompressor.c187
1 files changed, 98 insertions, 89 deletions
diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c
index 58f9c9f804e..d084f0847c7 100644
--- a/Modules/_zstd/decompressor.c
+++ b/Modules/_zstd/decompressor.c
@@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"
#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
+#include "internal/pycore_lock.h" // PyMutex_IsLocked
#include <stdbool.h> // bool
#include <stddef.h> // offsetof()
@@ -45,6 +46,9 @@ typedef struct {
/* For ZstdDecompressor, 0 or 1.
1 means the end of the first frame has been reached. */
bool eof;
+
+ /* Lock to protect the decompression context */
+ PyMutex lock;
} ZstdDecompressor;
#define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op)
@@ -54,6 +58,7 @@ typedef struct {
static inline ZSTD_DDict *
_get_DDict(ZstdDict *self)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_DDict *ret;
/* Already created */
@@ -61,18 +66,17 @@ _get_DDict(ZstdDict *self)
return self->d_dict;
}
- Py_BEGIN_CRITICAL_SECTION(self);
if (self->d_dict == NULL) {
/* Create ZSTD_DDict instance from dictionary content */
char *dict_buffer = PyBytes_AS_STRING(self->dict_content);
Py_ssize_t dict_len = Py_SIZE(self->dict_content);
Py_BEGIN_ALLOW_THREADS
- self->d_dict = ZSTD_createDDict(dict_buffer,
- dict_len);
+ ret = ZSTD_createDDict(dict_buffer, dict_len);
Py_END_ALLOW_THREADS
+ self->d_dict = ret;
if (self->d_dict == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Failed to create a ZSTD_DDict instance from "
@@ -81,11 +85,7 @@ _get_DDict(ZstdDict *self)
}
}
- /* Don't lose any exception */
- ret = self->d_dict;
- Py_END_CRITICAL_SECTION();
-
- return ret;
+ return self->d_dict;
}
/* Set decompression parameters to decompression context */
@@ -95,7 +95,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
size_t zstd_ret;
PyObject *key, *value;
Py_ssize_t pos;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}
@@ -112,7 +112,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
if (Py_TYPE(key) == mod_state->CParameter_type) {
PyErr_SetString(PyExc_TypeError,
"Key of decompression options dict should "
- "NOT be CompressionParameter.");
+ "NOT be a CompressionParameter attribute.");
return -1;
}
@@ -120,12 +120,11 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
int key_v = PyLong_AsInt(key);
if (key_v == -1 && PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError,
- "Key of options dict should be a DecompressionParameter attribute.");
+ "Key of options dict should be either a "
+ "DecompressionParameter attribute or an int.");
return -1;
}
- // TODO(emmatyping): check bounds when there is a value error here for better
- // error message?
int value_v = PyLong_AsInt(value);
if (value_v == -1 && PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError,
@@ -134,9 +133,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
}
/* Set parameter to compression context */
- Py_BEGIN_CRITICAL_SECTION(self);
zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v);
- Py_END_CRITICAL_SECTION();
/* Check error */
if (ZSTD_isError(zstd_ret)) {
@@ -147,12 +144,54 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options)
return 0;
}
+static int
+_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd,
+ _zstd_state *mod_state, int type)
+{
+ size_t zstd_ret;
+ if (type == DICT_TYPE_DIGESTED) {
+ /* Get ZSTD_DDict */
+ ZSTD_DDict *d_dict = _get_DDict(zd);
+ if (d_dict == NULL) {
+ return -1;
+ }
+ /* Reference a prepared dictionary */
+ zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
+ }
+ else if (type == DICT_TYPE_UNDIGESTED) {
+ /* Load a dictionary */
+ zstd_ret = ZSTD_DCtx_loadDictionary(
+ self->dctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else if (type == DICT_TYPE_PREFIX) {
+ /* Load a prefix */
+ zstd_ret = ZSTD_DCtx_refPrefix(
+ self->dctx,
+ PyBytes_AS_STRING(zd->dict_content),
+ Py_SIZE(zd->dict_content));
+ }
+ else {
+ /* Impossible code path */
+ PyErr_SetString(PyExc_SystemError,
+ "load_d_dict() impossible code path");
+ return -1;
+ }
+
+ /* Check error */
+ if (ZSTD_isError(zstd_ret)) {
+ set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
+ return -1;
+ }
+ return 0;
+}
+
/* Load dictionary or prefix to decompression context */
static int
_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
{
- size_t zstd_ret;
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
return -1;
}
@@ -168,7 +207,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
/* When decompressing, use digested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_DIGESTED;
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
/* Check (ZstdDict, type) */
@@ -182,13 +224,16 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
else if (ret > 0) {
/* type == -1 may indicate an error. */
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
- if (type == DICT_TYPE_DIGESTED ||
- type == DICT_TYPE_UNDIGESTED ||
- type == DICT_TYPE_PREFIX)
+ if (type == DICT_TYPE_DIGESTED
+ || type == DICT_TYPE_UNDIGESTED
+ || type == DICT_TYPE_PREFIX)
{
assert(type >= 0);
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
- goto load;
+ PyMutex_Lock(&zd->lock);
+ ret = _zstd_load_impl(self, zd, mod_state, type);
+ PyMutex_Unlock(&zd->lock);
+ return ret;
}
}
}
@@ -197,50 +242,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return -1;
-
-load:
- if (type == DICT_TYPE_DIGESTED) {
- /* Get ZSTD_DDict */
- ZSTD_DDict *d_dict = _get_DDict(zd);
- if (d_dict == NULL) {
- return -1;
- }
- /* Reference a prepared dictionary */
- Py_BEGIN_CRITICAL_SECTION(self);
- zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict);
- Py_END_CRITICAL_SECTION();
- }
- else if (type == DICT_TYPE_UNDIGESTED) {
- /* Load a dictionary */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_loadDictionary(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else if (type == DICT_TYPE_PREFIX) {
- /* Load a prefix */
- Py_BEGIN_CRITICAL_SECTION2(self, zd);
- zstd_ret = ZSTD_DCtx_refPrefix(
- self->dctx,
- PyBytes_AS_STRING(zd->dict_content),
- Py_SIZE(zd->dict_content));
- Py_END_CRITICAL_SECTION2();
- }
- else {
- /* Impossible code path */
- PyErr_SetString(PyExc_SystemError,
- "load_d_dict() impossible code path");
- return -1;
- }
-
- /* Check error */
- if (ZSTD_isError(zstd_ret)) {
- set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret);
- return -1;
- }
- return 0;
}
/*
@@ -268,8 +269,8 @@ load:
Note, decompressing "an empty input" in any case will make it > 0.
*/
static PyObject *
-decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
- Py_ssize_t max_length)
+decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in,
+ Py_ssize_t max_length)
{
size_t zstd_ret;
ZSTD_outBuffer out;
@@ -290,7 +291,7 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in,
/* Check error */
if (ZSTD_isError(zstd_ret)) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret);
}
@@ -339,10 +340,9 @@ error:
}
static void
-decompressor_reset_session(ZstdDecompressor *self)
+decompressor_reset_session_lock_held(ZstdDecompressor *self)
{
- // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here
- // and ensure lock is always held
+ assert(PyMutex_IsLocked(&self->lock));
/* Reset variables */
self->in_begin = 0;
@@ -359,15 +359,18 @@ decompressor_reset_session(ZstdDecompressor *self)
}
static PyObject *
-stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length)
+stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data,
+ Py_ssize_t max_length)
{
+ assert(PyMutex_IsLocked(&self->lock));
ZSTD_inBuffer in;
PyObject *ret = NULL;
int use_input_buffer;
/* Check .eof flag */
if (self->eof) {
- PyErr_SetString(PyExc_EOFError, "Already at the end of a Zstandard frame.");
+ PyErr_SetString(PyExc_EOFError,
+ "Already at the end of a Zstandard frame.");
assert(ret == NULL);
return NULL;
}
@@ -456,7 +459,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
assert(in.pos == 0);
/* Decompress */
- ret = decompress_impl(self, &in, max_length);
+ ret = decompress_lock_held(self, &in, max_length);
if (ret == NULL) {
goto error;
}
@@ -484,8 +487,8 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
if (!use_input_buffer) {
/* Discard buffer if it's too small
(resizing it may needlessly copy the current contents) */
- if (self->input_buffer != NULL &&
- self->input_buffer_size < data_size)
+ if (self->input_buffer != NULL
+ && self->input_buffer_size < data_size)
{
PyMem_Free(self->input_buffer);
self->input_buffer = NULL;
@@ -517,7 +520,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length
error:
/* Reset decompressor's states/session */
- decompressor_reset_session(self);
+ decompressor_reset_session_lock_held(self);
Py_CLEAR(ret);
return NULL;
@@ -555,6 +558,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
self->unused_data = NULL;
self->eof = 0;
self->dict = NULL;
+ self->lock = (PyMutex){0};
/* needs_input flag */
self->needs_input = 1;
@@ -562,7 +566,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict,
/* Decompression context */
self->dctx = ZSTD_createDCtx();
if (self->dctx == NULL) {
- _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self));
+ _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state != NULL) {
PyErr_SetString(mod_state->ZstdError,
"Unable to create ZSTD_DCtx instance.");
@@ -608,6 +612,8 @@ ZstdDecompressor_dealloc(PyObject *ob)
ZSTD_freeDCtx(self->dctx);
}
+ assert(!PyMutex_IsLocked(&self->lock));
+
/* Py_CLEAR the dict after free decompression context */
Py_CLEAR(self->dict);
@@ -623,7 +629,6 @@ ZstdDecompressor_dealloc(PyObject *ob)
}
/*[clinic input]
-@critical_section
@getter
_zstd.ZstdDecompressor.unused_data
@@ -635,11 +640,14 @@ decompressed, unused input data after the frame. Otherwise this will be b''.
static PyObject *
_zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
-/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/
+/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/
{
PyObject *ret;
+ PyMutex_Lock(&self->lock);
+
if (!self->eof) {
+ PyMutex_Unlock(&self->lock);
return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES);
}
else {
@@ -656,6 +664,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self)
}
}
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -693,10 +702,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self,
{
PyObject *ret;
/* Thread-safe code */
- Py_BEGIN_CRITICAL_SECTION(self);
-
- ret = stream_decompress(self, data, max_length);
- Py_END_CRITICAL_SECTION();
+ PyMutex_Lock(&self->lock);
+ ret = stream_decompress_lock_held(self, data, max_length);
+ PyMutex_Unlock(&self->lock);
return ret;
}
@@ -710,9 +718,10 @@ PyDoc_STRVAR(ZstdDecompressor_eof_doc,
"after that, an EOFError exception will be raised.");
PyDoc_STRVAR(ZstdDecompressor_needs_input_doc,
-"If the max_length output limit in .decompress() method has been reached, and\n"
-"the decompressor has (or may has) unconsumed input data, it will be set to\n"
-"False. In this case, pass b'' to .decompress() method may output further data.");
+"If the max_length output limit in .decompress() method has been reached,\n"
+"and the decompressor has (or may has) unconsumed input data, it will be set\n"
+"to False. In this case, passing b'' to the .decompress() method may output\n"
+"further data.");
static PyMemberDef ZstdDecompressor_members[] = {
{"eof", Py_T_BOOL, offsetof(ZstdDecompressor, eof),