aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Modules/_zstd/_zstdmodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_zstd/_zstdmodule.c')
-rw-r--r--Modules/_zstd/_zstdmodule.c89
1 files changed, 61 insertions, 28 deletions
diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c
index 5ad697d2b83..d75c0779474 100644
--- a/Modules/_zstd/_zstdmodule.c
+++ b/Modules/_zstd/_zstdmodule.c
@@ -7,7 +7,6 @@
#include "Python.h"
#include "_zstdmodule.h"
-#include "zstddict.h"
#include <zstd.h> // ZSTD_*()
#include <zdict.h> // ZDICT_*()
@@ -20,14 +19,52 @@ module _zstd
#include "clinic/_zstdmodule.c.h"
+ZstdDict *
+_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
+{
+ if (state == NULL) {
+ return NULL;
+ }
+
+ /* Check ZstdDict */
+ if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
+ return (ZstdDict*)dict;
+ }
+
+ /* Check (ZstdDict, type) */
+ if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
+ && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
+ && PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
+ {
+ int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
+ if (type == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ if (type == DICT_TYPE_DIGESTED
+ || type == DICT_TYPE_UNDIGESTED
+ || type == DICT_TYPE_PREFIX)
+ {
+ *ptype = type;
+ return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
+ }
+ }
+
+ /* Wrong type */
+ PyErr_SetString(PyExc_TypeError,
+ "zstd_dict argument should be a ZstdDict object.");
+ return NULL;
+}
+
/* Format error message and set ZstdError. */
void
-set_zstd_error(const _zstd_state* const state,
- error_type type, size_t zstd_ret)
+set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
{
- char *msg;
+ const char *msg;
assert(ZSTD_isError(zstd_ret));
+ if (state == NULL) {
+ return;
+ }
switch (type) {
case ERR_DECOMPRESS:
msg = "Unable to decompress Zstandard data: %s";
@@ -35,6 +72,9 @@ set_zstd_error(const _zstd_state* const state,
case ERR_COMPRESS:
msg = "Unable to compress Zstandard data: %s";
break;
+ case ERR_SET_PLEDGED_INPUT_SIZE:
+ msg = "Unable to set pledged uncompressed content size: %s";
+ break;
case ERR_LOAD_D_DICT:
msg = "Unable to load Zstandard dictionary or prefix for "
@@ -174,7 +214,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
Py_ssize_t sizes_sum;
Py_ssize_t i;
- chunks_number = Py_SIZE(samples_sizes);
+ chunks_number = PyTuple_GET_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
@@ -188,20 +228,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
return -1;
}
- sizes_sum = 0;
+ sizes_sum = PyBytes_GET_SIZE(samples_bytes);
for (i = 0; i < chunks_number; i++) {
- PyObject *size = PyTuple_GetItem(samples_sizes, i);
- (*chunk_sizes)[i] = PyLong_AsSize_t(size);
- if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
- PyErr_Format(PyExc_ValueError,
- "Items in samples_sizes should be an int "
- "object, with a value between 0 and %u.", SIZE_MAX);
+ size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
+ (*chunk_sizes)[i] = size;
+ if (size == (size_t)-1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ goto sum_error;
+ }
return -1;
}
- sizes_sum += (*chunk_sizes)[i];
+ if ((size_t)sizes_sum < size) {
+ goto sum_error;
+ }
+ sizes_sum -= size;
}
- if (sizes_sum != Py_SIZE(samples_bytes)) {
+ if (sizes_sum != 0) {
+sum_error:
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the "
"concatenation's size.");
@@ -257,7 +301,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
/* Train the dictionary */
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
- char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
+ const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
Py_BEGIN_ALLOW_THREADS
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
samples_buffer,
@@ -507,20 +551,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
{
_zstd_state* mod_state = get_zstd_state(module);
- if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
- PyErr_SetString(PyExc_ValueError,
- "The two arguments should be CompressionParameter and "
- "DecompressionParameter types.");
- return NULL;
- }
-
- Py_XDECREF(mod_state->CParameter_type);
Py_INCREF(c_parameter_type);
- mod_state->CParameter_type = (PyTypeObject*)c_parameter_type;
-
- Py_XDECREF(mod_state->DParameter_type);
+ Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
Py_INCREF(d_parameter_type);
- mod_state->DParameter_type = (PyTypeObject*)d_parameter_type;
+ Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
Py_RETURN_NONE;
}
@@ -583,7 +617,6 @@ do { \
return -1;
}
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
- Py_DECREF(mod_state->ZstdError);
return -1;
}