diff options
Diffstat (limited to 'Modules/_zstd/decompressor.c')
-rw-r--r-- | Modules/_zstd/decompressor.c | 187 |
1 files changed, 98 insertions, 89 deletions
diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 58f9c9f804e..d084f0847c7 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -17,6 +17,7 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" #include "zstddict.h" +#include "internal/pycore_lock.h" // PyMutex_IsLocked #include <stdbool.h> // bool #include <stddef.h> // offsetof() @@ -45,6 +46,9 @@ typedef struct { /* For ZstdDecompressor, 0 or 1. 1 means the end of the first frame has been reached. */ bool eof; + + /* Lock to protect the decompression context */ + PyMutex lock; } ZstdDecompressor; #define ZstdDecompressor_CAST(op) ((ZstdDecompressor *)op) @@ -54,6 +58,7 @@ typedef struct { static inline ZSTD_DDict * _get_DDict(ZstdDict *self) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_DDict *ret; /* Already created */ @@ -61,18 +66,17 @@ _get_DDict(ZstdDict *self) return self->d_dict; } - Py_BEGIN_CRITICAL_SECTION(self); if (self->d_dict == NULL) { /* Create ZSTD_DDict instance from dictionary content */ char *dict_buffer = PyBytes_AS_STRING(self->dict_content); Py_ssize_t dict_len = Py_SIZE(self->dict_content); Py_BEGIN_ALLOW_THREADS - self->d_dict = ZSTD_createDDict(dict_buffer, - dict_len); + ret = ZSTD_createDDict(dict_buffer, dict_len); Py_END_ALLOW_THREADS + self->d_dict = ret; if (self->d_dict == NULL) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { PyErr_SetString(mod_state->ZstdError, "Failed to create a ZSTD_DDict instance from " @@ -81,11 +85,7 @@ _get_DDict(ZstdDict *self) } } - /* Don't lose any exception */ - ret = self->d_dict; - Py_END_CRITICAL_SECTION(); - - return ret; + return self->d_dict; } /* Set decompression parameters to decompression context */ @@ -95,7 +95,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) size_t zstd_ret; PyObject *key, *value; Py_ssize_t pos; - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } @@ -112,7 +112,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) if (Py_TYPE(key) == mod_state->CParameter_type) { PyErr_SetString(PyExc_TypeError, "Key of decompression options dict should " - "NOT be CompressionParameter."); + "NOT be a CompressionParameter attribute."); return -1; } @@ -120,12 +120,11 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, - "Key of options dict should be a DecompressionParameter attribute."); + "Key of options dict should be either a " + "DecompressionParameter attribute or an int."); return -1; } - // TODO(emmatyping): check bounds when there is a value error here for better - // error message? int value_v = PyLong_AsInt(value); if (value_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, @@ -134,9 +133,7 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) } /* Set parameter to compression context */ - Py_BEGIN_CRITICAL_SECTION(self); zstd_ret = ZSTD_DCtx_setParameter(self->dctx, key_v, value_v); - Py_END_CRITICAL_SECTION(); /* Check error */ if (ZSTD_isError(zstd_ret)) { @@ -147,12 +144,54 @@ _zstd_set_d_parameters(ZstdDecompressor *self, PyObject *options) return 0; } +static int +_zstd_load_impl(ZstdDecompressor *self, ZstdDict *zd, + _zstd_state *mod_state, int type) +{ + size_t zstd_ret; + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_DDict */ + ZSTD_DDict *d_dict = _get_DDict(zd); + if (d_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary */ + zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary */ + zstd_ret = ZSTD_DCtx_loadDictionary( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + zstd_ret = ZSTD_DCtx_refPrefix( + self->dctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else { + /* Impossible code path */ + PyErr_SetString(PyExc_SystemError, + "load_d_dict() impossible code path"); + return -1; + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret); + return -1; + } + return 0; +} + /* Load dictionary or prefix to decompression context */ static int _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { - size_t zstd_ret; - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } @@ -168,7 +207,10 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) /* When decompressing, use digested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_DIGESTED; - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* Check (ZstdDict, type) */ @@ -182,13 +224,16 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) else if (ret > 0) { /* type == -1 may indicate an error. */ type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED || - type == DICT_TYPE_UNDIGESTED || - type == DICT_TYPE_PREFIX) + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) { assert(type >= 0); zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } } } @@ -197,50 +242,6 @@ _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) PyErr_SetString(PyExc_TypeError, "zstd_dict argument should be ZstdDict object."); return -1; - -load: - if (type == DICT_TYPE_DIGESTED) { - /* Get ZSTD_DDict */ - ZSTD_DDict *d_dict = _get_DDict(zd); - if (d_dict == NULL) { - return -1; - } - /* Reference a prepared dictionary */ - Py_BEGIN_CRITICAL_SECTION(self); - zstd_ret = ZSTD_DCtx_refDDict(self->dctx, d_dict); - Py_END_CRITICAL_SECTION(); - } - else if (type == DICT_TYPE_UNDIGESTED) { - /* Load a dictionary */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_DCtx_loadDictionary( - self->dctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else if (type == DICT_TYPE_PREFIX) { - /* Load a prefix */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_DCtx_refPrefix( - self->dctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else { - /* Impossible code path */ - PyErr_SetString(PyExc_SystemError, - "load_d_dict() impossible code path"); - return -1; - } - - /* Check error */ - if (ZSTD_isError(zstd_ret)) { - set_zstd_error(mod_state, ERR_LOAD_D_DICT, zstd_ret); - return -1; - } - return 0; } /* @@ -268,8 +269,8 @@ load: Note, decompressing "an empty input" in any case will make it > 0. */ static PyObject * -decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, - Py_ssize_t max_length) +decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in, + Py_ssize_t max_length) { size_t zstd_ret; ZSTD_outBuffer out; @@ -290,7 +291,7 @@ decompress_impl(ZstdDecompressor *self, ZSTD_inBuffer *in, /* Check error */ if (ZSTD_isError(zstd_ret)) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); } @@ -339,10 +340,9 @@ error: } static void -decompressor_reset_session(ZstdDecompressor *self) +decompressor_reset_session_lock_held(ZstdDecompressor *self) { - // TODO(emmatyping): use _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED here - // and ensure lock is always held + assert(PyMutex_IsLocked(&self->lock)); /* Reset variables */ self->in_begin = 0; @@ -359,15 +359,18 @@ decompressor_reset_session(ZstdDecompressor *self) } static PyObject * -stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length) +stream_decompress_lock_held(ZstdDecompressor *self, Py_buffer *data, + Py_ssize_t max_length) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; PyObject *ret = NULL; int use_input_buffer; /* Check .eof flag */ if (self->eof) { - PyErr_SetString(PyExc_EOFError, "Already at the end of a Zstandard frame."); + PyErr_SetString(PyExc_EOFError, + "Already at the end of a Zstandard frame."); assert(ret == NULL); return NULL; } @@ -456,7 +459,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length assert(in.pos == 0); /* Decompress */ - ret = decompress_impl(self, &in, max_length); + ret = decompress_lock_held(self, &in, max_length); if (ret == NULL) { goto error; } @@ -484,8 +487,8 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length if (!use_input_buffer) { /* Discard buffer if it's too small (resizing it may needlessly copy the current contents) */ - if (self->input_buffer != NULL && - self->input_buffer_size < data_size) + if (self->input_buffer != NULL + && self->input_buffer_size < data_size) { PyMem_Free(self->input_buffer); self->input_buffer = NULL; @@ -517,7 +520,7 @@ stream_decompress(ZstdDecompressor *self, Py_buffer *data, Py_ssize_t max_length error: /* Reset decompressor's states/session */ - decompressor_reset_session(self); + decompressor_reset_session_lock_held(self); Py_CLEAR(ret); return NULL; @@ -555,6 +558,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict, self->unused_data = NULL; self->eof = 0; self->dict = NULL; + self->lock = (PyMutex){0}; /* needs_input flag */ self->needs_input = 1; @@ -562,7 +566,7 @@ _zstd_ZstdDecompressor_new_impl(PyTypeObject *type, PyObject *zstd_dict, /* Decompression context */ self->dctx = ZSTD_createDCtx(); if (self->dctx == NULL) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { PyErr_SetString(mod_state->ZstdError, "Unable to create ZSTD_DCtx instance."); @@ -608,6 +612,8 @@ ZstdDecompressor_dealloc(PyObject *ob) ZSTD_freeDCtx(self->dctx); } + assert(!PyMutex_IsLocked(&self->lock)); + /* Py_CLEAR the dict after free decompression context */ Py_CLEAR(self->dict); @@ -623,7 +629,6 @@ ZstdDecompressor_dealloc(PyObject *ob) } /*[clinic input] -@critical_section @getter _zstd.ZstdDecompressor.unused_data @@ -635,11 +640,14 @@ decompressed, unused input data after the frame. Otherwise this will be b''. static PyObject * _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) -/*[clinic end generated code: output=f3a20940f11b6b09 input=5233800bef00df04]*/ +/*[clinic end generated code: output=f3a20940f11b6b09 input=54d41ecd681a3444]*/ { PyObject *ret; + PyMutex_Lock(&self->lock); + if (!self->eof) { + PyMutex_Unlock(&self->lock); return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES); } else { @@ -656,6 +664,7 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) } } + PyMutex_Unlock(&self->lock); return ret; } @@ -693,10 +702,9 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self, { PyObject *ret; /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - - ret = stream_decompress(self, data, max_length); - Py_END_CRITICAL_SECTION(); + PyMutex_Lock(&self->lock); + ret = stream_decompress_lock_held(self, data, max_length); + PyMutex_Unlock(&self->lock); return ret; } @@ -710,9 +718,10 @@ PyDoc_STRVAR(ZstdDecompressor_eof_doc, "after that, an EOFError exception will be raised."); PyDoc_STRVAR(ZstdDecompressor_needs_input_doc, -"If the max_length output limit in .decompress() method has been reached, and\n" -"the decompressor has (or may has) unconsumed input data, it will be set to\n" -"False. In this case, pass b'' to .decompress() method may output further data."); +"If the max_length output limit in .decompress() method has been reached,\n" +"and the decompressor has (or may has) unconsumed input data, it will be set\n" +"to False. In this case, passing b'' to the .decompress() method may output\n" +"further data."); static PyMemberDef ZstdDecompressor_members[] = { {"eof", Py_T_BOOL, offsetof(ZstdDecompressor, eof), |