diff options
Diffstat (limited to 'Modules/_zstd/compressor.c')
-rw-r--r-- | Modules/_zstd/compressor.c | 205 |
1 files changed, 110 insertions, 95 deletions
diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 38baee2be1e..31cb8c535c0 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -17,6 +17,7 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" #include "zstddict.h" +#include "internal/pycore_lock.h" // PyMutex_IsLocked #include <stddef.h> // offsetof() #include <zstd.h> // ZSTD_*() @@ -38,6 +39,9 @@ typedef struct { /* Compression level */ int compression_level; + + /* Lock to protect the compression context */ + PyMutex lock; } ZstdCompressor; #define ZstdCompressor_CAST(op) ((ZstdCompressor *)op) @@ -49,7 +53,7 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, const char *arg_name, const char* arg_type) { size_t zstd_ret; - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } @@ -59,8 +63,8 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, int level = PyLong_AsInt(level_or_options); if (level == -1 && PyErr_Occurred()) { PyErr_Format(PyExc_ValueError, - "Compression level should be an int value between %d and %d.", - ZSTD_minCLevel(), ZSTD_maxCLevel()); + "Compression level should be an int value between " + "%d and %d.", ZSTD_minCLevel(), ZSTD_maxCLevel()); return -1; } @@ -89,24 +93,23 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, /* Check key type */ if (Py_TYPE(key) == mod_state->DParameter_type) { PyErr_SetString(PyExc_TypeError, - "Key of compression option dict should " - "NOT be DecompressionParameter."); + "Key of compression options dict should " + "NOT be a DecompressionParameter attribute."); return -1; } int key_v = PyLong_AsInt(key); if (key_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, - "Key of options dict should be a CompressionParameter attribute."); + "Key of options dict should be either a " + "CompressionParameter attribute or an int."); return -1; } - // TODO(emmatyping): check bounds when there is a value error here for better - // error message? int value_v = PyLong_AsInt(value); if (value_v == -1 && PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, - "Value of option dict should be an int."); + "Value of options dict should be an int."); return -1; } @@ -135,7 +138,8 @@ _zstd_set_c_parameters(ZstdCompressor *self, PyObject *level_or_options, } return 0; } - PyErr_Format(PyExc_TypeError, "Invalid type for %s. Expected %s", arg_name, arg_type); + PyErr_Format(PyExc_TypeError, + "Invalid type for %s. Expected %s", arg_name, arg_type); return -1; } @@ -149,12 +153,12 @@ capsule_free_cdict(PyObject *capsule) ZSTD_CDict * _get_CDict(ZstdDict *self, int compressionLevel) { + assert(PyMutex_IsLocked(&self->lock)); PyObject *level = NULL; - PyObject *capsule; + PyObject *capsule = NULL; ZSTD_CDict *cdict; + int ret; - // TODO(emmatyping): refactor critical section code into a lock_held function - Py_BEGIN_CRITICAL_SECTION(self); /* int level object */ level = PyLong_FromLong(compressionLevel); @@ -163,12 +167,11 @@ _get_CDict(ZstdDict *self, int compressionLevel) } /* Get PyCapsule object from self->c_dicts */ - capsule = PyDict_GetItemWithError(self->c_dicts, level); + ret = PyDict_GetItemRef(self->c_dicts, level, &capsule); + if (ret < 0) { + goto error; + } if (capsule == NULL) { - if (PyErr_Occurred()) { - goto error; - } - /* Create ZSTD_CDict instance */ char *dict_buffer = PyBytes_AS_STRING(self->dict_content); Py_ssize_t dict_len = Py_SIZE(self->dict_content); @@ -179,7 +182,7 @@ _get_CDict(ZstdDict *self, int compressionLevel) Py_END_ALLOW_THREADS if (cdict == NULL) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { PyErr_SetString(mod_state->ZstdError, "Failed to create a ZSTD_CDict instance from " @@ -196,11 +199,10 @@ _get_CDict(ZstdDict *self, int compressionLevel) } /* Add PyCapsule object to self->c_dicts */ - if (PyDict_SetItem(self->c_dicts, level, capsule) < 0) { - Py_DECREF(capsule); + ret = PyDict_SetItem(self->c_dicts, level, capsule); + if (ret < 0) { goto error; } - Py_DECREF(capsule); } else { /* ZSTD_CDict instance already exists */ @@ -212,16 +214,56 @@ error: cdict = NULL; success: Py_XDECREF(level); - Py_END_CRITICAL_SECTION(); + Py_XDECREF(capsule); return cdict; } static int -_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) +_zstd_load_impl(ZstdCompressor *self, ZstdDict *zd, + _zstd_state *mod_state, int type) { - size_t zstd_ret; - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + if (type == DICT_TYPE_DIGESTED) { + /* Get ZSTD_CDict */ + ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level); + if (c_dict == NULL) { + return -1; + } + /* Reference a prepared dictionary. + It overrides some compression context's parameters. */ + zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict); + } + else if (type == DICT_TYPE_UNDIGESTED) { + /* Load a dictionary. + It doesn't override compression context's parameters. */ + zstd_ret = ZSTD_CCtx_loadDictionary( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else if (type == DICT_TYPE_PREFIX) { + /* Load a prefix */ + zstd_ret = ZSTD_CCtx_refPrefix( + self->cctx, + PyBytes_AS_STRING(zd->dict_content), + Py_SIZE(zd->dict_content)); + } + else { + Py_UNREACHABLE(); + } + + /* Check error */ + if (ZSTD_isError(zstd_ret)) { + set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret); + return -1; + } + return 0; +} + +static int +_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) +{ + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state == NULL) { return -1; } @@ -237,7 +279,10 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) /* When compressing, use undigested dictionary by default. */ zd = (ZstdDict*)dict; type = DICT_TYPE_UNDIGESTED; - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* Check (ZstdDict, type) */ @@ -251,13 +296,16 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) else if (ret > 0) { /* type == -1 may indicate an error. */ type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED || - type == DICT_TYPE_UNDIGESTED || - type == DICT_TYPE_PREFIX) + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) { assert(type >= 0); zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - goto load; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } } } @@ -266,49 +314,6 @@ _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) PyErr_SetString(PyExc_TypeError, "zstd_dict argument should be ZstdDict object."); return -1; - -load: - if (type == DICT_TYPE_DIGESTED) { - /* Get ZSTD_CDict */ - ZSTD_CDict *c_dict = _get_CDict(zd, self->compression_level); - if (c_dict == NULL) { - return -1; - } - /* Reference a prepared dictionary. - It overrides some compression context's parameters. */ - Py_BEGIN_CRITICAL_SECTION(self); - zstd_ret = ZSTD_CCtx_refCDict(self->cctx, c_dict); - Py_END_CRITICAL_SECTION(); - } - else if (type == DICT_TYPE_UNDIGESTED) { - /* Load a dictionary. - It doesn't override compression context's parameters. */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_CCtx_loadDictionary( - self->cctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else if (type == DICT_TYPE_PREFIX) { - /* Load a prefix */ - Py_BEGIN_CRITICAL_SECTION2(self, zd); - zstd_ret = ZSTD_CCtx_refPrefix( - self->cctx, - PyBytes_AS_STRING(zd->dict_content), - Py_SIZE(zd->dict_content)); - Py_END_CRITICAL_SECTION2(); - } - else { - Py_UNREACHABLE(); - } - - /* Check error */ - if (ZSTD_isError(zstd_ret)) { - set_zstd_error(mod_state, ERR_LOAD_C_DICT, zstd_ret); - return -1; - } - return 0; } /*[clinic input] @@ -339,11 +344,12 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level, self->use_multithread = 0; self->dict = NULL; + self->lock = (PyMutex){0}; /* Compression context */ self->cctx = ZSTD_createCCtx(); if (self->cctx == NULL) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { PyErr_SetString(mod_state->ZstdError, "Unable to create ZSTD_CCtx instance."); @@ -355,7 +361,8 @@ _zstd_ZstdCompressor_new_impl(PyTypeObject *type, PyObject *level, self->last_mode = ZSTD_e_end; if (level != Py_None && options != Py_None) { - PyErr_SetString(PyExc_RuntimeError, "Only one of level or options should be used."); + PyErr_SetString(PyExc_RuntimeError, + "Only one of level or options should be used."); goto error; } @@ -403,6 +410,8 @@ ZstdCompressor_dealloc(PyObject *ob) ZSTD_freeCCtx(self->cctx); } + assert(!PyMutex_IsLocked(&self->lock)); + /* Py_XDECREF the dict after free the compression context */ Py_CLEAR(self->dict); @@ -412,9 +421,10 @@ ZstdCompressor_dealloc(PyObject *ob) } static PyObject * -compress_impl(ZstdCompressor *self, Py_buffer *data, - ZSTD_EndDirective end_directive) +compress_lock_held(ZstdCompressor *self, Py_buffer *data, + ZSTD_EndDirective end_directive) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; ZSTD_outBuffer out; _BlocksOutputBuffer buffer = {.list = NULL}; @@ -441,7 +451,7 @@ compress_impl(ZstdCompressor *self, Py_buffer *data, } if (_OutputBuffer_InitWithSize(&buffer, &out, -1, - (Py_ssize_t) output_buffer_size) < 0) { + (Py_ssize_t) output_buffer_size) < 0) { goto error; } @@ -454,7 +464,7 @@ compress_impl(ZstdCompressor *self, Py_buffer *data, /* Check error */ if (ZSTD_isError(zstd_ret)) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); } @@ -495,8 +505,9 @@ mt_continue_should_break(ZSTD_inBuffer *in, ZSTD_outBuffer *out) #endif static PyObject * -compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data) +compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data) { + assert(PyMutex_IsLocked(&self->lock)); ZSTD_inBuffer in; ZSTD_outBuffer out; _BlocksOutputBuffer buffer = {.list = NULL}; @@ -516,20 +527,23 @@ compress_mt_continue_impl(ZstdCompressor *self, Py_buffer *data) while (1) { Py_BEGIN_ALLOW_THREADS do { - zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in, ZSTD_e_continue); - } while (out.pos != out.size && in.pos != in.size && !ZSTD_isError(zstd_ret)); + zstd_ret = ZSTD_compressStream2(self->cctx, &out, &in, + ZSTD_e_continue); + } while (out.pos != out.size + && in.pos != in.size + && !ZSTD_isError(zstd_ret)); Py_END_ALLOW_THREADS /* Check error */ if (ZSTD_isError(zstd_ret)) { - _zstd_state* const mod_state = PyType_GetModuleState(Py_TYPE(self)); + _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); if (mod_state != NULL) { set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); } goto error; } - /* Like compress_impl(), output as much as possible. */ + /* Like compress_lock_held(), output as much as possible. */ if (out.pos == out.size) { if (_OutputBuffer_Grow(&buffer, &out) < 0) { goto error; @@ -588,14 +602,14 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, } /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); + PyMutex_Lock(&self->lock); /* Compress */ if (self->use_multithread && mode == ZSTD_e_continue) { - ret = compress_mt_continue_impl(self, data); + ret = compress_mt_continue_lock_held(self, data); } else { - ret = compress_impl(self, data, mode); + ret = compress_lock_held(self, data, mode); } if (ret) { @@ -607,7 +621,7 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); + PyMutex_Unlock(&self->lock); return ret; } @@ -642,8 +656,9 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) } /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - ret = compress_impl(self, NULL, mode); + PyMutex_Lock(&self->lock); + + ret = compress_lock_held(self, NULL, mode); if (ret) { self->last_mode = mode; @@ -654,7 +669,7 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); + PyMutex_Unlock(&self->lock); return ret; } @@ -668,12 +683,12 @@ static PyMethodDef ZstdCompressor_methods[] = { PyDoc_STRVAR(ZstdCompressor_last_mode_doc, "The last mode used to this compressor object, its value can be .CONTINUE,\n" ".FLUSH_BLOCK, .FLUSH_FRAME. Initialized to .FLUSH_FRAME.\n\n" -"It can be used to get the current state of a compressor, such as, data flushed,\n" -"a frame ended."); +"It can be used to get the current state of a compressor, such as, data\n" +"flushed, or a frame ended."); static PyMemberDef ZstdCompressor_members[] = { {"last_mode", Py_T_INT, offsetof(ZstdCompressor, last_mode), - Py_READONLY, ZstdCompressor_last_mode_doc}, + Py_READONLY, ZstdCompressor_last_mode_doc}, {NULL} }; |