diff options
author | Eric Snow <ericsnowcurrently@gmail.com> | 2025-04-30 17:34:05 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-30 17:34:05 -0600 |
commit | cb35c11d82efd2959bda0397abcc1719bf6bb0cb (patch) | |
tree | 24401105bae0254b93e5d31488fa0a814d0f90be /Python/crossinterp.c | |
parent | 6c522debc218d441756bf631abe8ec8d6c6f1c45 (diff) | |
download | cpython-cb35c11d82efd2959bda0397abcc1719bf6bb0cb.tar.gz cpython-cb35c11d82efd2959bda0397abcc1719bf6bb0cb.zip |
gh-132775: Add _PyPickle_GetXIData() (gh-133107)
There's some extra complexity due to making sure we get things right when handling functions and classes defined in the __main__ module. This is also reflected in the tests, including the addition of extra functions in test.support.import_helper.
Diffstat (limited to 'Python/crossinterp.c')
-rw-r--r-- | Python/crossinterp.c | 452 |
1 files changed, 452 insertions, 0 deletions
diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 753d784a503..a9f9b785629 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -3,6 +3,7 @@ #include "Python.h" #include "marshal.h" // PyMarshal_WriteObjectToString() +#include "osdefs.h" // MAXPATHLEN #include "pycore_ceval.h" // _Py_simple_func #include "pycore_crossinterp.h" // _PyXIData_t #include "pycore_initconfig.h" // _PyStatus_OK() @@ -10,6 +11,155 @@ #include "pycore_typeobject.h" // _PyStaticType_InitBuiltin() +static Py_ssize_t +_Py_GetMainfile(char *buffer, size_t maxlen) +{ + // We don't expect subinterpreters to have the __main__ module's + // __name__ set, but proceed just in case. + PyThreadState *tstate = _PyThreadState_GET(); + PyObject *module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + return -1; + } + Py_ssize_t size = _PyModule_GetFilenameUTF8(module, buffer, maxlen); + Py_DECREF(module); + return size; +} + + +static PyObject * +import_get_module(PyThreadState *tstate, const char *modname) +{ + PyObject *module = NULL; + if (strcmp(modname, "__main__") == 0) { + module = _Py_GetMainModule(tstate); + if (_Py_CheckMainModule(module) < 0) { + assert(_PyErr_Occurred(tstate)); + return NULL; + } + } + else { + module = PyImport_ImportModule(modname); + if (module == NULL) { + return NULL; + } + } + return module; +} + + +static PyObject * +runpy_run_path(const char *filename, const char *modname) +{ + PyObject *run_path = PyImport_ImportModuleAttrString("runpy", "run_path"); + if (run_path == NULL) { + return NULL; + } + PyObject *args = Py_BuildValue("(sOs)", filename, Py_None, modname); + if (args == NULL) { + Py_DECREF(run_path); + return NULL; + } + PyObject *ns = PyObject_Call(run_path, args, NULL); + Py_DECREF(run_path); + Py_DECREF(args); + return ns; +} + + +static PyObject * +pyerr_get_message(PyObject *exc) +{ + assert(!PyErr_Occurred()); + PyObject *args = PyException_GetArgs(exc); + if (args == NULL || args == Py_None || 
PyObject_Size(args) < 1) { + return NULL; + } + if (PyUnicode_Check(args)) { + return args; + } + PyObject *msg = PySequence_GetItem(args, 0); + Py_DECREF(args); + if (msg == NULL) { + PyErr_Clear(); + return NULL; + } + if (!PyUnicode_Check(msg)) { + Py_DECREF(msg); + return NULL; + } + return msg; +} + +#define MAX_MODNAME (255) +#define MAX_ATTRNAME (255) + +struct attributeerror_info { + char modname[MAX_MODNAME+1]; + char attrname[MAX_ATTRNAME+1]; +}; + +static int +_parse_attributeerror(PyObject *exc, struct attributeerror_info *info) +{ + assert(exc != NULL); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + int res = -1; + + PyObject *msgobj = pyerr_get_message(exc); + if (msgobj == NULL) { + return -1; + } + const char *err = PyUnicode_AsUTF8(msgobj); + + if (strncmp(err, "module '", 8) != 0) { + goto finally; + } + err += 8; + + const char *matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + Py_ssize_t len = matched - err; + if (len > MAX_MODNAME) { + goto finally; + } + (void)strncpy(info->modname, err, len); + info->modname[len] = '\0'; + err = matched; + + if (strncmp(err, "' has no attribute '", 20) != 0) { + goto finally; + } + err += 20; + + matched = strchr(err, '\''); + if (matched == NULL) { + goto finally; + } + len = matched - err; + if (len > MAX_ATTRNAME) { + goto finally; + } + (void)strncpy(info->attrname, err, len); + info->attrname[len] = '\0'; + err = matched + 1; + + if (strlen(err) > 0) { + goto finally; + } + res = 0; + +finally: + Py_DECREF(msgobj); + return res; +} + +#undef MAX_MODNAME +#undef MAX_ATTRNAME + + /**************/ /* exceptions */ /**************/ @@ -287,6 +437,308 @@ _PyObject_GetXIData(PyThreadState *tstate, } +/* pickle C-API */ + +struct _pickle_context { + PyThreadState *tstate; +}; + +static PyObject * +_PyPickle_Dumps(struct _pickle_context *ctx, PyObject *obj) +{ + PyObject *dumps = PyImport_ImportModuleAttrString("pickle", "dumps"); + if (dumps == NULL) { + return 
NULL; + } + PyObject *bytes = PyObject_CallOneArg(dumps, obj); + Py_DECREF(dumps); + return bytes; +} + + +struct sync_module_result { + PyObject *module; + PyObject *loaded; + PyObject *failed; +}; + +struct sync_module { + const char *filename; + char _filename[MAXPATHLEN+1]; + struct sync_module_result cached; +}; + +static void +sync_module_clear(struct sync_module *data) +{ + data->filename = NULL; + Py_CLEAR(data->cached.module); + Py_CLEAR(data->cached.loaded); + Py_CLEAR(data->cached.failed); +} + + +struct _unpickle_context { + PyThreadState *tstate; + // We only special-case the __main__ module, + // since other modules behave consistently. + struct sync_module main; +}; + +static void +_unpickle_context_clear(struct _unpickle_context *ctx) +{ + sync_module_clear(&ctx->main); +} + +static struct sync_module_result +_unpickle_context_get_module(struct _unpickle_context *ctx, + const char *modname) +{ + if (strcmp(modname, "__main__") == 0) { + return ctx->main.cached; + } + else { + return (struct sync_module_result){ + .failed = PyExc_NotImplementedError, + }; + } +} + +static struct sync_module_result +_unpickle_context_set_module(struct _unpickle_context *ctx, + const char *modname) +{ + struct sync_module_result res = {0}; + struct sync_module_result *cached = NULL; + const char *filename = NULL; + if (strcmp(modname, "__main__") == 0) { + cached = &ctx->main.cached; + filename = ctx->main.filename; + } + else { + res.failed = PyExc_NotImplementedError; + goto finally; + } + + res.module = import_get_module(ctx->tstate, modname); + if (res.module == NULL) { + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != NULL); + goto finally; + } + + if (filename == NULL) { + Py_CLEAR(res.module); + res.failed = PyExc_NotImplementedError; + goto finally; + } + res.loaded = runpy_run_path(filename, modname); + if (res.loaded == NULL) { + Py_CLEAR(res.module); + res.failed = _PyErr_GetRaisedException(ctx->tstate); + assert(res.failed != 
NULL); + goto finally; + } + +finally: + if (cached != NULL) { + assert(cached->module == NULL); + assert(cached->loaded == NULL); + assert(cached->failed == NULL); + *cached = res; + } + return res; +} + + +static int +_handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc) +{ + // The caller must check if an exception is set or not when -1 is returned. + assert(!_PyErr_Occurred(ctx->tstate)); + assert(PyErr_GivenExceptionMatches(exc, PyExc_AttributeError)); + struct attributeerror_info info; + if (_parse_attributeerror(exc, &info) < 0) { + return -1; + } + + // Get the module. + struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); + if (mod.failed != NULL) { + // It must have failed previously. + return -1; + } + if (mod.module == NULL) { + mod = _unpickle_context_set_module(ctx, info.modname); + if (mod.failed != NULL) { + return -1; + } + assert(mod.module != NULL); + } + + // Bail out if it is unexpectedly set already. + if (PyObject_HasAttrString(mod.module, info.attrname)) { + return -1; + } + + // Try setting the attribute. + PyObject *value = NULL; + if (PyDict_GetItemStringRef(mod.loaded, info.attrname, &value) <= 0) { + return -1; + } + assert(value != NULL); + int res = PyObject_SetAttrString(mod.module, info.attrname, value); + Py_DECREF(value); + if (res < 0) { + return -1; + } + + return 0; +} + +static PyObject * +_PyPickle_Loads(struct _unpickle_context *ctx, PyObject *pickled) +{ + PyObject *loads = PyImport_ImportModuleAttrString("pickle", "loads"); + if (loads == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg(loads, pickled); + if (ctx != NULL) { + while (obj == NULL) { + assert(_PyErr_Occurred(ctx->tstate)); + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + // We leave other failures unhandled. + break; + } + // Try setting the attr if not set. 
+ PyObject *exc = _PyErr_GetRaisedException(ctx->tstate); + if (_handle_unpickle_missing_attr(ctx, exc) < 0) { + // Any resulting exceptions are ignored + // in favor of the original. + _PyErr_SetRaisedException(ctx->tstate, exc); + break; + } + Py_CLEAR(exc); + // Retry with the attribute set. + obj = PyObject_CallOneArg(loads, pickled); + } + } + Py_DECREF(loads); + return obj; +} + + +/* pickle wrapper */ + +struct _pickle_xid_context { + // __main__.__file__ + struct { + const char *utf8; + size_t len; + char _utf8[MAXPATHLEN+1]; + } mainfile; +}; + +static int +_set_pickle_xid_context(PyThreadState *tstate, struct _pickle_xid_context *ctx) +{ + // Set mainfile if possible. + Py_ssize_t len = _Py_GetMainfile(ctx->mainfile._utf8, MAXPATHLEN); + if (len < 0) { + // For now we ignore any exceptions. + PyErr_Clear(); + } + else if (len > 0) { + ctx->mainfile.utf8 = ctx->mainfile._utf8; + ctx->mainfile.len = (size_t)len; + } + + return 0; +} + + +struct _shared_pickle_data { + _PyBytes_data_t pickled; // Must be first if we use _PyBytes_FromXIData(). + struct _pickle_xid_context ctx; +}; + +PyObject * +_PyPickle_LoadFromXIData(_PyXIData_t *xidata) +{ + PyThreadState *tstate = _PyThreadState_GET(); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)xidata->data; + // We avoid copying the pickled data by wrapping it in a memoryview. + // The alternative is to get a bytes object using _PyBytes_FromXIData(). + PyObject *pickled = PyMemoryView_FromMemory( + (char *)shared->pickled.bytes, shared->pickled.len, PyBUF_READ); + if (pickled == NULL) { + return NULL; + } + + // Unpickle the object. 
+ struct _unpickle_context ctx = { + .tstate = tstate, + .main = { + .filename = shared->ctx.mainfile.utf8, + }, + }; + PyObject *obj = _PyPickle_Loads(&ctx, pickled); + Py_DECREF(pickled); + _unpickle_context_clear(&ctx); + if (obj == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be unpickled", cause); + Py_DECREF(cause); + } + return obj; +} + + +int +_PyPickle_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata) +{ + // Pickle the object. + struct _pickle_context ctx = { + .tstate = tstate, + }; + PyObject *bytes = _PyPickle_Dumps(&ctx, obj); + if (bytes == NULL) { + PyObject *cause = _PyErr_GetRaisedException(tstate); + assert(cause != NULL); + _set_xid_lookup_failure( + tstate, NULL, "object could not be pickled", cause); + Py_DECREF(cause); + return -1; + } + + // If we had an "unwrapper" mechnanism, we could call + // _PyObject_GetXIData() on the bytes object directly and add + // a simple unwrapper to call pickle.loads() on the bytes. + size_t size = sizeof(struct _shared_pickle_data); + struct _shared_pickle_data *shared = + (struct _shared_pickle_data *)_PyBytes_GetXIDataWrapped( + tstate, bytes, size, _PyPickle_LoadFromXIData, xidata); + Py_DECREF(bytes); + if (shared == NULL) { + return -1; + } + + // If it mattered, we could skip getting __main__.__file__ + // when "__main__" doesn't show up in the pickle bytes. + if (_set_pickle_xid_context(tstate, &shared->ctx) < 0) { + _xidata_clear(xidata); + return -1; + } + + return 0; +} + + /* marshal wrapper */ PyObject * |