aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_pyio.py23
-rw-r--r--Lib/asyncio/selector_events.py2
-rw-r--r--Lib/logging/__init__.py1
-rw-r--r--Lib/multiprocessing/connection.py2
-rw-r--r--Lib/multiprocessing/util.py79
-rw-r--r--Lib/pdb.py3
-rw-r--r--Lib/tempfile.py5
-rw-r--r--Lib/test/support/import_helper.py2
-rw-r--r--Lib/test/test__interpreters.py3
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py16
-rw-r--r--Lib/test/test_bufio.py2
-rw-r--r--Lib/test/test_capi/test_opt.py19
-rw-r--r--Lib/test/test_crossinterp.py511
-rw-r--r--Lib/test/test_embed.py4
-rw-r--r--Lib/test/test_io.py8
-rw-r--r--Lib/test/test_threading.py1
-rw-r--r--Lib/test/test_zipfile/test_core.py19
-rw-r--r--Lib/zipfile/__init__.py48
18 files changed, 633 insertions, 115 deletions
diff --git a/Lib/_pyio.py b/Lib/_pyio.py
index a870de5b532..fb2a6d049ca 100644
--- a/Lib/_pyio.py
+++ b/Lib/_pyio.py
@@ -407,6 +407,9 @@ class IOBase(metaclass=abc.ABCMeta):
if closed:
return
+ if dealloc_warn := getattr(self, "_dealloc_warn", None):
+ dealloc_warn(self)
+
# If close() fails, the caller logs the exception with
# sys.unraisablehook. close() must be called at the end at __del__().
self.close()
@@ -645,8 +648,6 @@ class RawIOBase(IOBase):
self._unsupported("write")
io.RawIOBase.register(RawIOBase)
-from _io import FileIO
-RawIOBase.register(FileIO)
class BufferedIOBase(IOBase):
@@ -853,6 +854,10 @@ class _BufferedIOMixin(BufferedIOBase):
else:
return "<{}.{} name={!r}>".format(modname, clsname, name)
+ def _dealloc_warn(self, source):
+ if dealloc_warn := getattr(self.raw, "_dealloc_warn", None):
+ dealloc_warn(source)
+
### Lower-level APIs ###
def fileno(self):
@@ -1563,7 +1568,8 @@ class FileIO(RawIOBase):
if not isinstance(fd, int):
raise TypeError('expected integer from opener')
if fd < 0:
- raise OSError('Negative file descriptor')
+ # bpo-27066: Raise a ValueError for bad value.
+ raise ValueError(f'opener returned {fd}')
owned_fd = fd
if not noinherit_flag:
os.set_inheritable(fd, False)
@@ -1600,12 +1606,11 @@ class FileIO(RawIOBase):
raise
self._fd = fd
- def __del__(self):
+ def _dealloc_warn(self, source):
if self._fd >= 0 and self._closefd and not self.closed:
import warnings
- warnings.warn('unclosed file %r' % (self,), ResourceWarning,
+ warnings.warn(f'unclosed file {source!r}', ResourceWarning,
stacklevel=2, source=self)
- self.close()
def __getstate__(self):
raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")
@@ -1780,7 +1785,7 @@ class FileIO(RawIOBase):
if not self.closed:
self._stat_atopen = None
try:
- if self._closefd:
+ if self._closefd and self._fd >= 0:
os.close(self._fd)
finally:
super().close()
@@ -2689,6 +2694,10 @@ class TextIOWrapper(TextIOBase):
def newlines(self):
return self._decoder.newlines if self._decoder else None
+ def _dealloc_warn(self, source):
+ if dealloc_warn := getattr(self.buffer, "_dealloc_warn", None):
+ dealloc_warn(source)
+
class StringIO(TextIOWrapper):
"""Text I/O implementation using an in-memory buffer.
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 22147451fa7..6ad84044adf 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -173,7 +173,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
# See https://bugs.python.org/issue27906 for more details.
- for _ in range(backlog):
+ for _ in range(backlog + 1):
try:
conn, addr = sock.accept()
if self._debug:
diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py
index f2d1a02629d..5c3c4424934 100644
--- a/Lib/logging/__init__.py
+++ b/Lib/logging/__init__.py
@@ -591,6 +591,7 @@ class Formatter(object):
%(threadName)s Thread name (if available)
%(taskName)s Task name (if available)
%(process)d Process ID (if available)
+ %(processName)s Process name (if available)
%(message)s The result of record.getMessage(), computed just as
the record is emitted
"""
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py
index 5f288a8d393..fc00d286126 100644
--- a/Lib/multiprocessing/connection.py
+++ b/Lib/multiprocessing/connection.py
@@ -76,7 +76,7 @@ def arbitrary_address(family):
if family == 'AF_INET':
return ('localhost', 0)
elif family == 'AF_UNIX':
- return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir())
+ return tempfile.mktemp(prefix='sock-', dir=util.get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), next(_mmap_counter)), dir="")
diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py
index b7192042b9c..a1a537dd48d 100644
--- a/Lib/multiprocessing/util.py
+++ b/Lib/multiprocessing/util.py
@@ -19,7 +19,7 @@ from subprocess import _args_from_interpreter_flags # noqa: F401
from . import process
__all__ = [
- 'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger',
+ 'sub_debug', 'debug', 'info', 'sub_warning', 'warn', 'get_logger',
'log_to_stderr', 'get_temp_dir', 'register_after_fork',
'is_exiting', 'Finalize', 'ForkAwareThreadLock', 'ForkAwareLocal',
'close_all_fds_except', 'SUBDEBUG', 'SUBWARNING',
@@ -34,6 +34,7 @@ SUBDEBUG = 5
DEBUG = 10
INFO = 20
SUBWARNING = 25
+WARNING = 30
LOGGER_NAME = 'multiprocessing'
DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s'
@@ -53,6 +54,10 @@ def info(msg, *args):
if _logger:
_logger.log(INFO, msg, *args, stacklevel=2)
+def warn(msg, *args):
+ if _logger:
+ _logger.log(WARNING, msg, *args, stacklevel=2)
+
def sub_warning(msg, *args):
if _logger:
_logger.log(SUBWARNING, msg, *args, stacklevel=2)
@@ -121,6 +126,21 @@ abstract_sockets_supported = _platform_supports_abstract_sockets()
# Function returning a temp directory which will be removed on exit
#
+# Maximum length of a socket file path is usually between 92 and 108 [1],
+# but Linux is known to use a size of 108 [2]. BSD-based systems usually
+# use a size of 104 or 108 and Windows does not create AF_UNIX sockets.
+#
+# [1]: https://pubs.opengroup.org/onlinepubs/9799919799/basedefs/sys_un.h.html
+# [2]: https://man7.org/linux/man-pages/man7/unix.7.html.
+
+if sys.platform == 'linux':
+ _SUN_PATH_MAX = 108
+elif sys.platform.startswith(('openbsd', 'freebsd')):
+ _SUN_PATH_MAX = 104
+else:
+ # On Windows platforms, we do not create AF_UNIX sockets.
+ _SUN_PATH_MAX = None if os.name == 'nt' else 92
+
def _remove_temp_dir(rmtree, tempdir):
rmtree(tempdir)
@@ -130,12 +150,67 @@ def _remove_temp_dir(rmtree, tempdir):
if current_process is not None:
current_process._config['tempdir'] = None
+def _get_base_temp_dir(tempfile):
+ """Get a temporary directory where socket files will be created.
+
+ To prevent additional imports, pass a pre-imported 'tempfile' module.
+ """
+ if os.name == 'nt':
+ return None
+ # Most of the time, the default temporary directory is /tmp. Thus,
+ # listener sockets files "$TMPDIR/pymp-XXXXXXXX/sock-XXXXXXXX" do
+ # not have a path length exceeding SUN_PATH_MAX.
+ #
+ # If users specify their own temporary directory, we may be unable
+ # to create those files. Therefore, we fall back to the system-wide
+ # temporary directory /tmp, assumed to exist on POSIX systems.
+ #
+ # See https://github.com/python/cpython/issues/132124.
+ base_tempdir = tempfile.gettempdir()
+ # Files created in a temporary directory are suffixed by a string
+ # generated by tempfile._RandomNameSequence, which, by design,
+ # is 8 characters long.
+ #
+ # Thus, the length of socket filename will be:
+ #
+ # len(base_tempdir + '/pymp-XXXXXXXX' + '/sock-XXXXXXXX')
+ sun_path_len = len(base_tempdir) + 14 + 14
+ if sun_path_len <= _SUN_PATH_MAX:
+ return base_tempdir
+ # Fallback to the default system-wide temporary directory.
+ # This ignores user-defined environment variables.
+ #
+ # On POSIX systems, /tmp MUST be writable by any application [1].
+ # We however emit a warning if this is not the case to prevent
+ # obscure errors later in the execution.
+ #
+ # On some legacy systems, /var/tmp and /usr/tmp can be present
+ # and will be used instead.
+ #
+ # [1]: https://refspecs.linuxfoundation.org/FHS_3.0/fhs/ch03s18.html
+ dirlist = ['/tmp', '/var/tmp', '/usr/tmp']
+ try:
+ base_system_tempdir = tempfile._get_default_tempdir(dirlist)
+ except FileNotFoundError:
+ warn("Process-wide temporary directory %s will not be usable for "
+ "creating socket files and no usable system-wide temporary "
+ "directory was found in %s", base_tempdir, dirlist)
+ # At this point, the system-wide temporary directory is not usable
+ # but we may assume that the user-defined one is, even if we will
+ # not be able to write socket files out there.
+ return base_tempdir
+ warn("Ignoring user-defined temporary directory: %s", base_tempdir)
+ # at most max(map(len, dirlist)) + 14 + 14 = 36 characters
+ assert len(base_system_tempdir) + 14 + 14 <= _SUN_PATH_MAX
+ return base_system_tempdir
+
def get_temp_dir():
# get name of a temp directory which will be automatically cleaned up
tempdir = process.current_process()._config.get('tempdir')
if tempdir is None:
import shutil, tempfile
- tempdir = tempfile.mkdtemp(prefix='pymp-')
+ base_tempdir = _get_base_temp_dir(tempfile)
+ tempdir = tempfile.mkdtemp(prefix='pymp-', dir=base_tempdir)
info('created temp directory %s', tempdir)
# keep a strong reference to shutil.rmtree(), since the finalizer
# can be called late during Python shutdown
diff --git a/Lib/pdb.py b/Lib/pdb.py
index 544c701bbd2..78ee35f61bb 100644
--- a/Lib/pdb.py
+++ b/Lib/pdb.py
@@ -3489,7 +3489,8 @@ def help():
_usage = """\
Debug the Python program given by pyfile. Alternatively,
an executable module or package to debug can be specified using
-the -m switch.
+the -m switch. You can also attach to a running Python process
+using the -p option with its PID.
Initial commands are read from .pdbrc files in your home directory
and in the current directory, if they exist. Commands supplied with
diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index cadb0bed3cc..5e3ccab5f48 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -180,7 +180,7 @@ def _candidate_tempdir_list():
return dirlist
-def _get_default_tempdir():
+def _get_default_tempdir(dirlist=None):
"""Calculate the default directory to use for temporary files.
This routine should be called exactly once.
@@ -190,7 +190,8 @@ def _get_default_tempdir():
service, the name of the test file must be randomized."""
namer = _RandomNameSequence()
- dirlist = _candidate_tempdir_list()
+ if dirlist is None:
+ dirlist = _candidate_tempdir_list()
for dir in dirlist:
if dir != _os.curdir:
diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py
index edb734d294f..0af63501f93 100644
--- a/Lib/test/support/import_helper.py
+++ b/Lib/test/support/import_helper.py
@@ -438,5 +438,5 @@ def ensure_module_imported(name, *, clearnone=True):
if sys.modules.get(name) is not None:
mod = sys.modules[name]
else:
- mod, _, _ = _force_import(name, False, True, clearnone)
+ mod, _, _ = _ensure_module(name, False, True, clearnone)
return mod
diff --git a/Lib/test/test__interpreters.py b/Lib/test/test__interpreters.py
index 0c43f46300f..63fdaad8de7 100644
--- a/Lib/test/test__interpreters.py
+++ b/Lib/test/test__interpreters.py
@@ -1054,7 +1054,7 @@ class RunFuncTests(TestBase):
def script():
assert spam
- with self.assertRaises(ValueError):
+ with self.assertRaises(TypeError):
_interpreters.run_func(self.id, script)
# XXX This hasn't been fixed yet.
@@ -1065,6 +1065,7 @@ class RunFuncTests(TestBase):
with self.assertRaises(ValueError):
_interpreters.run_func(self.id, script)
+ @unittest.skip("we're not quite there yet")
def test_args(self):
with self.subTest('args'):
def script(a, b=0):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index de81936b745..aab6a779170 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -347,6 +347,18 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
selectors.EVENT_WRITE)])
self.loop._remove_writer.assert_called_with(1)
+ def test_accept_connection_zero_one(self):
+ for backlog in [0, 1]:
+ sock = mock.Mock()
+ sock.accept.return_value = (mock.Mock(), mock.Mock())
+ with self.subTest(backlog):
+ mock_obj = mock.patch.object
+ with mock_obj(self.loop, '_accept_connection2') as accept2_mock:
+ self.loop._accept_connection(
+ mock.Mock(), sock, backlog=backlog)
+ self.loop.run_until_complete(asyncio.sleep(0))
+ self.assertEqual(sock.accept.call_count, backlog + 1)
+
def test_accept_connection_multiple(self):
sock = mock.Mock()
sock.accept.return_value = (mock.Mock(), mock.Mock())
@@ -362,7 +374,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
self.loop._accept_connection(
mock.Mock(), sock, backlog=backlog)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
def test_accept_connection_skip_connectionabortederror(self):
sock = mock.Mock()
@@ -388,7 +400,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
# as in test_accept_connection_multiple avoid task pending
# warnings by using asyncio.sleep(0)
self.loop.run_until_complete(asyncio.sleep(0))
- self.assertEqual(sock.accept.call_count, backlog)
+ self.assertEqual(sock.accept.call_count, backlog + 1)
class SelectorTransportTests(test_utils.TestCase):
diff --git a/Lib/test/test_bufio.py b/Lib/test/test_bufio.py
index dc9a82dc635..cb9cb4d0bc7 100644
--- a/Lib/test/test_bufio.py
+++ b/Lib/test/test_bufio.py
@@ -28,7 +28,7 @@ class BufferSizeTest:
f.write(b"\n")
f.write(s)
f.close()
- f = open(os_helper.TESTFN, "rb")
+ f = self.open(os_helper.TESTFN, "rb")
line = f.readline()
self.assertEqual(line, s + b"\n")
line = f.readline()
diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py
index 50c4f19a1ab..98b434313e4 100644
--- a/Lib/test/test_capi/test_opt.py
+++ b/Lib/test/test_capi/test_opt.py
@@ -2137,6 +2137,25 @@ class TestUopsOptimization(unittest.TestCase):
self.assertNotIn("_TO_BOOL_BOOL", uops)
self.assertIn("_GUARD_IS_TRUE_POP", uops)
+ def test_set_type_version_sets_type(self):
+ class C:
+ A = 1
+
+ def testfunc(n):
+ x = 0
+ c = C()
+ for _ in range(n):
+ x += c.A # Guarded.
+ x += type(c).A # Unguarded!
+ return x
+
+ res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)
+ self.assertEqual(res, 2 * TIER2_THRESHOLD)
+ self.assertIsNotNone(ex)
+ uops = get_opnames(ex)
+ self.assertIn("_GUARD_TYPE_VERSION", uops)
+ self.assertNotIn("_CHECK_ATTR_CLASS", uops)
+
def global_identity(x):
return x
diff --git a/Lib/test/test_crossinterp.py b/Lib/test/test_crossinterp.py
index cddacbc9970..c54635eaeab 100644
--- a/Lib/test/test_crossinterp.py
+++ b/Lib/test/test_crossinterp.py
@@ -5,6 +5,7 @@ import itertools
import sys
import types
import unittest
+import warnings
from test.support import import_helper
@@ -16,13 +17,281 @@ from test import _code_definitions as code_defs
from test import _crossinterp_definitions as defs
-BUILTIN_TYPES = [o for _, o in __builtins__.items()
- if isinstance(o, type)]
-EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES
+@contextlib.contextmanager
+def ignore_byteswarning():
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore', category=BytesWarning)
+ yield
+
+
+# builtin types
+
+BUILTINS_TYPES = [o for _, o in __builtins__.items() if isinstance(o, type)]
+EXCEPTION_TYPES = [cls for cls in BUILTINS_TYPES
if issubclass(cls, BaseException)]
OTHER_TYPES = [o for n, o in vars(types).items()
if (isinstance(o, type) and
- n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+ n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
+BUILTIN_TYPES = [
+ *BUILTINS_TYPES,
+ *OTHER_TYPES,
+]
+
+# builtin exceptions
+
+try:
+ raise Exception
+except Exception as exc:
+ CAUGHT = exc
+EXCEPTIONS_WITH_SPECIAL_SIG = {
+ BaseExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ ExceptionGroup: (lambda msg: (msg, [CAUGHT])),
+ UnicodeError: (lambda msg: (None, msg, None, None, None)),
+ UnicodeEncodeError: (lambda msg: ('utf-8', '', 1, 3, msg)),
+ UnicodeDecodeError: (lambda msg: ('utf-8', b'', 1, 3, msg)),
+ UnicodeTranslateError: (lambda msg: ('', 1, 3, msg)),
+}
+BUILTIN_EXCEPTIONS = [
+ *(cls(*sig('error!')) for cls, sig in EXCEPTIONS_WITH_SPECIAL_SIG.items()),
+ *(cls('error!') for cls in EXCEPTION_TYPES
+ if cls not in EXCEPTIONS_WITH_SPECIAL_SIG),
+]
+
+# other builtin objects
+
+METHOD = defs.SpamOkay().okay
+BUILTIN_METHOD = [].append
+METHOD_DESCRIPTOR_WRAPPER = str.join
+METHOD_WRAPPER = object().__str__
+WRAPPER_DESCRIPTOR = object.__init__
+BUILTIN_WRAPPERS = {
+ METHOD: types.MethodType,
+ BUILTIN_METHOD: types.BuiltinMethodType,
+ dict.__dict__['fromkeys']: types.ClassMethodDescriptorType,
+ types.FunctionType.__code__: types.GetSetDescriptorType,
+ types.FunctionType.__globals__: types.MemberDescriptorType,
+ METHOD_DESCRIPTOR_WRAPPER: types.MethodDescriptorType,
+ METHOD_WRAPPER: types.MethodWrapperType,
+ WRAPPER_DESCRIPTOR: types.WrapperDescriptorType,
+ staticmethod(defs.SpamOkay.okay): None,
+ classmethod(defs.SpamOkay.okay): None,
+ property(defs.SpamOkay.okay): None,
+}
+BUILTIN_FUNCTIONS = [
+ # types.BuiltinFunctionType
+ len,
+ sys.is_finalizing,
+ sys.exit,
+ _testinternalcapi.get_crossinterp_data,
+]
+assert 'emptymod' not in sys.modules
+with import_helper.ready_to_import('emptymod', ''):
+ import emptymod as EMPTYMOD
+MODULES = [
+ sys,
+ defs,
+ unittest,
+ EMPTYMOD,
+]
+OBJECT = object()
+EXCEPTION = Exception()
+LAMBDA = (lambda: None)
+BUILTIN_SIMPLE = [
+ OBJECT,
+ # singletons
+ None,
+ True,
+ False,
+ Ellipsis,
+ NotImplemented,
+ # bytes
+ *(i.to_bytes(2, 'little', signed=True)
+ for i in range(-1, 258)),
+ # str
+ 'hello world',
+ '你好世界',
+ '',
+ # int
+ sys.maxsize + 1,
+ sys.maxsize,
+ -sys.maxsize - 1,
+ -sys.maxsize - 2,
+ *range(-1, 258),
+ 2**1000,
+ # float
+ 0.0,
+ 1.1,
+ -1.0,
+ 0.12345678,
+ -0.12345678,
+]
+TUPLE_EXCEPTION = (0, 1.0, EXCEPTION)
+TUPLE_OBJECT = (0, 1.0, OBJECT)
+TUPLE_NESTED_EXCEPTION = (0, 1.0, (EXCEPTION,))
+TUPLE_NESTED_OBJECT = (0, 1.0, (OBJECT,))
+MEMORYVIEW_EMPTY = memoryview(b'')
+MEMORYVIEW_NOT_EMPTY = memoryview(b'spam'*42)
+MAPPING_PROXY_EMPTY = types.MappingProxyType({})
+BUILTIN_CONTAINERS = [
+ # tuple (flat)
+ (),
+ (1,),
+ ("hello", "world", ),
+ (1, True, "hello"),
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ # tuple (nested)
+ ((1,),),
+ ((1, 2), (3, 4)),
+ ((1, 2), (3, 4), (5, 6)),
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+ # buffer
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ # list
+ [],
+ [1, 2, 3],
+ [[1], (2,), {3: 4}],
+ # dict
+ {},
+ {1: 7, 2: 8, 3: 9},
+ {1: [1], 2: (2,), 3: {3: 4}},
+ # set
+ set(),
+ {1, 2, 3},
+ {frozenset({1}), (2,)},
+ # frozenset
+ frozenset([]),
+ frozenset({frozenset({1}), (2,)}),
+ # bytearray
+ bytearray(b''),
+ # other
+ MAPPING_PROXY_EMPTY,
+ types.SimpleNamespace(),
+]
+ns = {}
+exec("""
+try:
+ raise Exception
+except Exception as exc:
+ TRACEBACK = exc.__traceback__
+ FRAME = TRACEBACK.tb_frame
+""", ns, ns)
+BUILTIN_OTHER = [
+ # types.CellType
+ types.CellType(),
+ # types.FrameType
+ ns['FRAME'],
+ # types.TracebackType
+ ns['TRACEBACK'],
+]
+del ns
+
+# user-defined objects
+
+USER_TOP_INSTANCES = [c(*a) for c, a in defs.TOP_CLASSES.items()]
+USER_NESTED_INSTANCES = [c(*a) for c, a in defs.NESTED_CLASSES.items()]
+USER_INSTANCES = [
+ *USER_TOP_INSTANCES,
+ *USER_NESTED_INSTANCES,
+]
+USER_EXCEPTIONS = [
+ defs.MimimalError('error!'),
+]
+
+# shareable objects
+
+TUPLES_WITHOUT_EQUALITY = [
+ TUPLE_EXCEPTION,
+ TUPLE_OBJECT,
+ TUPLE_NESTED_EXCEPTION,
+ TUPLE_NESTED_OBJECT,
+]
+_UNSHAREABLE_SIMPLE = [
+ Ellipsis,
+ NotImplemented,
+ OBJECT,
+ sys.maxsize + 1,
+ -sys.maxsize - 2,
+ 2**1000,
+]
+with ignore_byteswarning():
+ _SHAREABLE_SIMPLE = [o for o in BUILTIN_SIMPLE
+ if o not in _UNSHAREABLE_SIMPLE]
+ _SHAREABLE_CONTAINERS = [
+ *(o for o in BUILTIN_CONTAINERS if type(o) is memoryview),
+ *(o for o in BUILTIN_CONTAINERS
+ if type(o) is tuple and o not in TUPLES_WITHOUT_EQUALITY),
+ ]
+ _UNSHAREABLE_CONTAINERS = [o for o in BUILTIN_CONTAINERS
+ if o not in _SHAREABLE_CONTAINERS]
+SHAREABLE = [
+ *_SHAREABLE_SIMPLE,
+ *_SHAREABLE_CONTAINERS,
+]
+NOT_SHAREABLE = [
+ *_UNSHAREABLE_SIMPLE,
+ *_UNSHAREABLE_CONTAINERS,
+ *BUILTIN_TYPES,
+ *BUILTIN_WRAPPERS,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *MODULES,
+ *BUILTIN_OTHER,
+ # types.CodeType
+ *(f.__code__ for f in defs.FUNCTIONS),
+ *(f.__code__ for f in defs.FUNCTION_LIKE),
+ # types.FunctionType
+ *defs.FUNCTIONS,
+ defs.SpamOkay.okay,
+ LAMBDA,
+ *defs.FUNCTION_LIKE,
+ # coroutines and generators
+ *defs.FUNCTION_LIKE_APPLIED,
+ # user classes
+ *defs.CLASSES,
+ *USER_INSTANCES,
+ # user exceptions
+ *USER_EXCEPTIONS,
+]
+
+# pickleable objects
+
+PICKLEABLE = [
+ *BUILTIN_SIMPLE,
+ *(o for o in BUILTIN_CONTAINERS if o not in [
+ MEMORYVIEW_EMPTY,
+ MEMORYVIEW_NOT_EMPTY,
+ MAPPING_PROXY_EMPTY,
+ ] or type(o) is dict),
+ *BUILTINS_TYPES,
+ *BUILTIN_EXCEPTIONS,
+ *BUILTIN_FUNCTIONS,
+ *defs.TOP_FUNCTIONS,
+ defs.SpamOkay.okay,
+ *defs.FUNCTION_LIKE,
+ *defs.TOP_CLASSES,
+ *USER_TOP_INSTANCES,
+ *USER_EXCEPTIONS,
+ # from OTHER_TYPES
+ types.NoneType,
+ types.EllipsisType,
+ types.NotImplementedType,
+ types.GenericAlias,
+ types.UnionType,
+ types.SimpleNamespace,
+ # from BUILTIN_WRAPPERS
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_DESCRIPTOR_WRAPPER,
+ METHOD_WRAPPER,
+ WRAPPER_DESCRIPTOR,
+]
+assert not any(isinstance(o, types.MappingProxyType) for o in PICKLEABLE)
+
+
+# helpers
DEFS = defs
with open(code_defs.__file__) as infile:
@@ -111,6 +380,77 @@ class _GetXIDataTests(unittest.TestCase):
MODE = None
+ def assert_functions_equal(self, func1, func2):
+ assert type(func1) is types.FunctionType, repr(func1)
+ assert type(func2) is types.FunctionType, repr(func2)
+ self.assertEqual(func1.__name__, func2.__name__)
+ self.assertEqual(func1.__code__, func2.__code__)
+ self.assertEqual(func1.__defaults__, func2.__defaults__)
+ self.assertEqual(func1.__kwdefaults__, func2.__kwdefaults__)
+ # We don't worry about __globals__ for now.
+
+ def assert_exc_args_equal(self, exc1, exc2):
+ args1 = exc1.args
+ args2 = exc2.args
+ if isinstance(exc1, ExceptionGroup):
+ self.assertIs(type(args1), type(args2))
+ self.assertEqual(len(args1), 2)
+ self.assertEqual(len(args1), len(args2))
+ self.assertEqual(args1[0], args2[0])
+ group1 = args1[1]
+ group2 = args2[1]
+ self.assertEqual(len(group1), len(group2))
+ for grouped1, grouped2 in zip(group1, group2):
+ # Currently the "extra" attrs are not preserved
+ # (via __reduce__).
+ self.assertIs(type(exc1), type(exc2))
+ self.assert_exc_equal(grouped1, grouped2)
+ else:
+ self.assertEqual(args1, args2)
+
+ def assert_exc_equal(self, exc1, exc2):
+ self.assertIs(type(exc1), type(exc2))
+
+ if type(exc1).__eq__ is not object.__eq__:
+ self.assertEqual(exc1, exc2)
+
+ self.assert_exc_args_equal(exc1, exc2)
+ # XXX For now we do not preserve tracebacks.
+ if exc1.__traceback__ is not None:
+ self.assertEqual(exc1.__traceback__, exc2.__traceback__)
+ self.assertEqual(
+ getattr(exc1, '__notes__', None),
+ getattr(exc2, '__notes__', None),
+ )
+ # We assume there are no cycles.
+ if exc1.__cause__ is None:
+ self.assertIs(exc1.__cause__, exc2.__cause__)
+ else:
+ self.assert_exc_equal(exc1.__cause__, exc2.__cause__)
+ if exc1.__context__ is None:
+ self.assertIs(exc1.__context__, exc2.__context__)
+ else:
+ self.assert_exc_equal(exc1.__context__, exc2.__context__)
+
+ def assert_equal_or_equalish(self, obj, expected):
+ cls = type(expected)
+ if cls.__eq__ is not object.__eq__:
+ self.assertEqual(obj, expected)
+ elif cls is types.FunctionType:
+ self.assert_functions_equal(obj, expected)
+ elif isinstance(expected, BaseException):
+ self.assert_exc_equal(obj, expected)
+ elif cls is types.MethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.BuiltinMethodType:
+ raise NotImplementedError(cls)
+ elif cls is types.MethodWrapperType:
+ raise NotImplementedError(cls)
+ elif cls.__bases__ == (object,):
+ self.assertEqual(obj.__dict__, expected.__dict__)
+ else:
+ raise NotImplementedError(cls)
+
def get_xidata(self, obj, *, mode=None):
mode = self._resolve_mode(mode)
return _testinternalcapi.get_crossinterp_data(obj, mode)
@@ -126,35 +466,37 @@ class _GetXIDataTests(unittest.TestCase):
def assert_roundtrip_identical(self, values, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIs(got, obj)
def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
- self.assertEqual(got, obj)
+ if got is obj:
+ continue
self.assertIs(type(got),
type(obj) if expecttype is None else expecttype)
+ self.assert_equal_or_equalish(got, obj)
def assert_roundtrip_equal_not_identical(self, values, *,
mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIsNot(got, obj)
self.assertIs(type(got),
type(obj) if expecttype is None else expecttype)
- self.assertEqual(got, obj)
+ self.assert_equal_or_equalish(got, obj)
def assert_roundtrip_not_equal(self, values, *,
mode=None, expecttype=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
got = self._get_roundtrip(obj, mode)
self.assertIsNot(got, obj)
self.assertIs(type(got),
@@ -164,7 +506,7 @@ class _GetXIDataTests(unittest.TestCase):
def assert_not_shareable(self, values, exctype=None, *, mode=None):
mode = self._resolve_mode(mode)
for obj in values:
- with self.subTest(obj):
+ with self.subTest(repr(obj)):
with self.assertRaises(NotShareableError) as cm:
_testinternalcapi.get_crossinterp_data(obj, mode)
if exctype is not None:
@@ -182,49 +524,26 @@ class PickleTests(_GetXIDataTests):
MODE = 'pickle'
def test_shareable(self):
- self.assert_roundtrip_equal([
- # singletons
- None,
- True,
- False,
- # bytes
- *(i.to_bytes(2, 'little', signed=True)
- for i in range(-1, 258)),
- # str
- 'hello world',
- '你好世界',
- '',
- # int
- sys.maxsize,
- -sys.maxsize - 1,
- *range(-1, 258),
- # float
- 0.0,
- 1.1,
- -1.0,
- 0.12345678,
- -0.12345678,
- # tuple
- (),
- (1,),
- ("hello", "world", ),
- (1, True, "hello"),
- ((1,),),
- ((1, 2), (3, 4)),
- ((1, 2), (3, 4), (5, 6)),
- ])
- # not shareable using xidata
- self.assert_roundtrip_equal([
- # int
- sys.maxsize + 1,
- -sys.maxsize - 2,
- 2**1000,
- # tuple
- (0, 1.0, []),
- (0, 1.0, {}),
- (0, 1.0, ([],)),
- (0, 1.0, ({},)),
- ])
+ with ignore_byteswarning():
+ for obj in SHAREABLE:
+ if obj in PICKLEABLE:
+ self.assert_roundtrip_equal([obj])
+ else:
+ self.assert_not_shareable([obj])
+
+ def test_not_shareable(self):
+ with ignore_byteswarning():
+ for obj in NOT_SHAREABLE:
+ if type(obj) is types.MappingProxyType:
+ self.assert_not_shareable([obj])
+ elif obj in PICKLEABLE:
+ with self.subTest(repr(obj)):
+ # We don't worry about checking the actual value.
+ # The other tests should cover that well enough.
+ got = self.get_roundtrip(obj)
+ self.assertIs(type(got), type(obj))
+ else:
+ self.assert_not_shareable([obj])
def test_list(self):
self.assert_roundtrip_equal_not_identical([
@@ -266,7 +585,7 @@ class PickleTests(_GetXIDataTests):
if cls not in defs.CLASSES_WITHOUT_EQUALITY:
continue
instances.append(cls(*args))
- self.assert_roundtrip_not_equal(instances)
+ self.assert_roundtrip_equal(instances)
def assert_class_defs_other_pickle(self, defs, mod):
# Pickle relative to a different module than the original.
@@ -286,7 +605,7 @@ class PickleTests(_GetXIDataTests):
instances = []
for cls, args in defs.TOP_CLASSES.items():
- with self.subTest(cls):
+ with self.subTest(repr(cls)):
setattr(mod, cls.__name__, cls)
xid = self.get_xidata(cls)
inst = cls(*args)
@@ -295,7 +614,7 @@ class PickleTests(_GetXIDataTests):
(cls, xid, inst, instxid))
for cls, xid, inst, instxid in instances:
- with self.subTest(cls):
+ with self.subTest(repr(cls)):
delattr(mod, cls.__name__)
if fail:
with self.assertRaises(NotShareableError):
@@ -403,13 +722,13 @@ class PickleTests(_GetXIDataTests):
def assert_func_defs_other_pickle(self, defs, mod):
# Pickle relative to a different module than the original.
for func in defs.TOP_FUNCTIONS:
- assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
self.assert_not_shareable(defs.TOP_FUNCTIONS)
def assert_func_defs_other_unpickle(self, defs, mod, *, fail=False):
# Unpickle relative to a different module than the original.
for func in defs.TOP_FUNCTIONS:
- assert not hasattr(mod, func.__name__), (cls, getattr(mod, func.__name__))
+ assert not hasattr(mod, func.__name__), (getattr(mod, func.__name__),)
captured = []
for func in defs.TOP_FUNCTIONS:
@@ -434,7 +753,7 @@ class PickleTests(_GetXIDataTests):
self.assert_not_shareable(defs.TOP_FUNCTIONS)
def test_user_function_normal(self):
-# self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
+ self.assert_roundtrip_equal(defs.TOP_FUNCTIONS)
self.assert_func_defs_same(defs)
def test_user_func_in___main__(self):
@@ -505,7 +824,7 @@ class PickleTests(_GetXIDataTests):
# exceptions
def test_user_exception_normal(self):
- self.assert_roundtrip_not_equal([
+ self.assert_roundtrip_equal([
defs.MimimalError('error!'),
])
self.assert_roundtrip_equal_not_identical([
@@ -521,7 +840,7 @@ class PickleTests(_GetXIDataTests):
special = {
BaseExceptionGroup: (msg, [caught]),
ExceptionGroup: (msg, [caught]),
-# UnicodeError: (None, msg, None, None, None),
+ UnicodeError: (None, msg, None, None, None),
UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
UnicodeTranslateError: ('', 1, 3, msg),
@@ -531,7 +850,7 @@ class PickleTests(_GetXIDataTests):
args = special.get(cls) or (msg,)
exceptions.append(cls(*args))
- self.assert_roundtrip_not_equal(exceptions)
+ self.assert_roundtrip_equal(exceptions)
class MarshalTests(_GetXIDataTests):
@@ -576,7 +895,7 @@ class MarshalTests(_GetXIDataTests):
'',
])
self.assert_not_shareable([
- object(),
+ OBJECT,
types.SimpleNamespace(),
])
@@ -647,10 +966,7 @@ class MarshalTests(_GetXIDataTests):
shareable = [
StopIteration,
]
- types = [
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ]
+ types = BUILTIN_TYPES
self.assert_not_shareable(cls for cls in types
if cls not in shareable)
self.assert_roundtrip_identical(cls for cls in types
@@ -763,7 +1079,7 @@ class ShareableFuncTests(_GetXIDataTests):
MODE = 'func'
def test_stateless(self):
- self.assert_roundtrip_not_equal([
+ self.assert_roundtrip_equal([
*defs.STATELESS_FUNCTIONS,
# Generators can be stateless too.
*defs.FUNCTION_LIKE,
@@ -912,10 +1228,49 @@ class ShareableScriptTests(PureShareableScriptTests):
], expecttype=types.CodeType)
+class ShareableFallbackTests(_GetXIDataTests):
+
+ MODE = 'fallback'
+
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
+ def test_not_shareable(self):
+ okay = [
+ *PICKLEABLE,
+ *defs.STATELESS_FUNCTIONS,
+ LAMBDA,
+ ]
+ ignored = [
+ *TUPLES_WITHOUT_EQUALITY,
+ OBJECT,
+ METHOD,
+ BUILTIN_METHOD,
+ METHOD_WRAPPER,
+ ]
+ with ignore_byteswarning():
+ self.assert_roundtrip_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in okay and o not in ignored
+ and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_roundtrip_not_equal([
+ *(o for o in NOT_SHAREABLE
+ if o in ignored and o is not MAPPING_PROXY_EMPTY),
+ ])
+ self.assert_not_shareable([
+ *(o for o in NOT_SHAREABLE if o not in okay),
+ MAPPING_PROXY_EMPTY,
+ ])
+
+
class ShareableTypeTests(_GetXIDataTests):
MODE = 'xidata'
+ def test_shareable(self):
+ self.assert_roundtrip_equal(SHAREABLE)
+
def test_singletons(self):
self.assert_roundtrip_identical([
None,
@@ -983,8 +1338,8 @@ class ShareableTypeTests(_GetXIDataTests):
def test_tuples_containing_non_shareable_types(self):
non_shareables = [
- Exception(),
- object(),
+ EXCEPTION,
+ OBJECT,
]
for s in non_shareables:
value = tuple([0, 1.0, s])
@@ -999,6 +1354,9 @@ class ShareableTypeTests(_GetXIDataTests):
# The rest are not shareable.
+ def test_not_shareable(self):
+ self.assert_not_shareable(NOT_SHAREABLE)
+
def test_object(self):
self.assert_not_shareable([
object(),
@@ -1015,12 +1373,12 @@ class ShareableTypeTests(_GetXIDataTests):
for func in defs.FUNCTIONS:
assert type(func) is types.FunctionType, func
assert type(defs.SpamOkay.okay) is types.FunctionType, func
- assert type(lambda: None) is types.LambdaType
+ assert type(LAMBDA) is types.LambdaType
self.assert_not_shareable([
*defs.FUNCTIONS,
defs.SpamOkay.okay,
- (lambda: None),
+ LAMBDA,
])
def test_builtin_function(self):
@@ -1085,10 +1443,7 @@ class ShareableTypeTests(_GetXIDataTests):
self.assert_not_shareable(instances)
def test_builtin_type(self):
- self.assert_not_shareable([
- *BUILTIN_TYPES,
- *OTHER_TYPES,
- ])
+ self.assert_not_shareable(BUILTIN_TYPES)
def test_exception(self):
self.assert_not_shareable([
@@ -1127,7 +1482,7 @@ class ShareableTypeTests(_GetXIDataTests):
""", ns, ns)
self.assert_not_shareable([
- types.MappingProxyType({}),
+ MAPPING_PROXY_EMPTY,
types.SimpleNamespace(),
# types.CellType
types.CellType(),
diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py
index 46222e521ae..89f4aebe28f 100644
--- a/Lib/test/test_embed.py
+++ b/Lib/test/test_embed.py
@@ -1915,6 +1915,10 @@ class AuditingTests(EmbeddingTestsMixin, unittest.TestCase):
self.run_embedded_interpreter("test_get_incomplete_frame")
+ def test_gilstate_after_finalization(self):
+ self.run_embedded_interpreter("test_gilstate_after_finalization")
+
+
class MiscTests(EmbeddingTestsMixin, unittest.TestCase):
def test_unicode_id_init(self):
# bpo-42882: Test that _PyUnicode_FromId() works
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 90680c6d47a..aa619a96ab7 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -918,7 +918,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -1
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -1')
def test_bad_opener_other_negative(self):
@@ -926,7 +926,7 @@ class IOTest(unittest.TestCase):
def badopener(fname, flags):
return -2
with self.assertRaises(ValueError) as cm:
- open('non-existent', 'r', opener=badopener)
+ self.open('non-existent', 'r', opener=badopener)
self.assertEqual(str(cm.exception), 'opener returned -2')
def test_opener_invalid_fd(self):
@@ -4417,7 +4417,7 @@ class MiscIOTest(unittest.TestCase):
self._check_abc_inheritance(io)
def _check_warn_on_dealloc(self, *args, **kwargs):
- f = open(*args, **kwargs)
+ f = self.open(*args, **kwargs)
r = repr(f)
with self.assertWarns(ResourceWarning) as cm:
f = None
@@ -4446,7 +4446,7 @@ class MiscIOTest(unittest.TestCase):
r, w = os.pipe()
fds += r, w
with warnings_helper.check_no_resource_warning(self):
- open(r, *args, closefd=False, **kwargs)
+ self.open(r, *args, closefd=False, **kwargs)
@unittest.skipUnless(hasattr(os, "pipe"), "requires os.pipe()")
def test_warn_on_dealloc_fd(self):
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index a9eec139bec..0e51e7fc8c5 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -1353,6 +1353,7 @@ class ThreadTests(BaseTestCase):
assert_python_ok("-c", script)
@skip_unless_reliable_fork
+ @unittest.skipUnless(hasattr(threading, 'get_native_id'), "test needs threading.get_native_id()")
def test_native_id_after_fork(self):
script = """if True:
import threading
diff --git a/Lib/test/test_zipfile/test_core.py b/Lib/test/test_zipfile/test_core.py
index 43056978848..e93603998f9 100644
--- a/Lib/test/test_zipfile/test_core.py
+++ b/Lib/test/test_zipfile/test_core.py
@@ -1991,6 +1991,25 @@ class OtherTests(unittest.TestCase):
self.assertFalse(zipfile.is_zipfile(fp))
fp.seek(0, 0)
self.assertFalse(zipfile.is_zipfile(fp))
+ # - passing non-zipfile with ZIP header elements
+ # data created using pyPNG like so:
+ # d = [(ord('P'), ord('K'), 5, 6), (ord('P'), ord('K'), 6, 6)]
+ # w = png.Writer(1,2,alpha=True,compression=0)
+ # f = open('onepix.png', 'wb')
+ # w.write(f, d)
+ # w.close()
+ data = (b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
+ b"\x00\x02\x08\x06\x00\x00\x00\x99\x81\xb6'\x00\x00\x00\x15I"
+ b"DATx\x01\x01\n\x00\xf5\xff\x00PK\x05\x06\x00PK\x06\x06\x07"
+ b"\xac\x01N\xc6|a\r\x00\x00\x00\x00IEND\xaeB`\x82")
+ # - passing a filename
+ with open(TESTFN, "wb") as fp:
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(TESTFN))
+ # - passing a file-like object
+ fp = io.BytesIO()
+ fp.write(data)
+ self.assertFalse(zipfile.is_zipfile(fp))
def test_damaged_zipfile(self):
"""Check that zipfiles with missing bytes at the end raise BadZipFile."""
diff --git a/Lib/zipfile/__init__.py b/Lib/zipfile/__init__.py
index 894b4d37233..18caeb3e04a 100644
--- a/Lib/zipfile/__init__.py
+++ b/Lib/zipfile/__init__.py
@@ -234,8 +234,19 @@ class _Extra(bytes):
def _check_zipfile(fp):
try:
- if _EndRecData(fp):
- return True # file has correct magic number
+ endrec = _EndRecData(fp)
+ if endrec:
+ if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
+ return True # Empty zipfiles are still zipfiles
+ elif endrec[_ECD_DISK_NUMBER] == endrec[_ECD_DISK_START]:
+ # Central directory is on the same disk
+ fp.seek(sum(_handle_prepended_data(endrec)))
+ if endrec[_ECD_SIZE] >= sizeCentralDir:
+ data = fp.read(sizeCentralDir) # CD is where we expect it to be
+ if len(data) == sizeCentralDir:
+ centdir = struct.unpack(structCentralDir, data) # CD is the right size
+ if centdir[_CD_SIGNATURE] == stringCentralDir:
+ return True # First central directory entry has correct magic number
except OSError:
pass
return False
@@ -258,6 +269,22 @@ def is_zipfile(filename):
pass
return result
+def _handle_prepended_data(endrec, debug=0):
+ size_cd = endrec[_ECD_SIZE] # bytes in central directory
+ offset_cd = endrec[_ECD_OFFSET] # offset of central directory
+
+ # "concat" is zero, unless zip was concatenated to another file
+ concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
+ if endrec[_ECD_SIGNATURE] == stringEndArchive64:
+ # If Zip64 extension structures are present, account for them
+ concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+
+ if debug > 2:
+ inferred = concat + offset_cd
+ print("given, inferred, offset", offset_cd, inferred, concat)
+
+ return offset_cd, concat
+
def _EndRecData64(fpin, offset, endrec):
"""
Read the ZIP64 end-of-archive records and use that to update endrec
@@ -1501,28 +1528,21 @@ class ZipFile:
raise BadZipFile("File is not a zip file")
if self.debug > 1:
print(endrec)
- size_cd = endrec[_ECD_SIZE] # bytes in central directory
- offset_cd = endrec[_ECD_OFFSET] # offset of central directory
self._comment = endrec[_ECD_COMMENT] # archive comment
- # "concat" is zero, unless zip was concatenated to another file
- concat = endrec[_ECD_LOCATION] - size_cd - offset_cd
- if endrec[_ECD_SIGNATURE] == stringEndArchive64:
- # If Zip64 extension structures are present, account for them
- concat -= (sizeEndCentDir64 + sizeEndCentDir64Locator)
+ offset_cd, concat = _handle_prepended_data(endrec, self.debug)
+
+ # self.start_dir: Position of start of central directory
+ self.start_dir = offset_cd + concat
# store the offset to the beginning of data for the
# .data_offset property
self._data_offset = concat
- if self.debug > 2:
- inferred = concat + offset_cd
- print("given, inferred, offset", offset_cd, inferred, concat)
- # self.start_dir: Position of start of central directory
- self.start_dir = offset_cd + concat
if self.start_dir < 0:
raise BadZipFile("Bad offset for central directory")
fp.seek(self.start_dir, 0)
+ size_cd = endrec[_ECD_SIZE]
data = fp.read(size_cd)
fp = io.BytesIO(data)
total = 0