Add custom stream type for compressed resources

This commit is contained in:
dgelessus 2020-08-01 14:00:13 +02:00
parent 8d39469e6e
commit 9e6dfacff6
2 changed files with 93 additions and 4 deletions

View File

@ -252,9 +252,8 @@ class Resource(object):
try: try:
return self._data_decompressed return self._data_decompressed
except AttributeError: except AttributeError:
with self.open_raw() as f: with self.open() as f:
f.seek(self.compressed_info.header_length) self._data_decompressed = f.read()
self._data_decompressed = b"".join(compress.decompress_stream_parsed(self.compressed_info, f))
return self._data_decompressed return self._data_decompressed
else: else:
return self.data_raw return self.data_raw
@ -272,7 +271,12 @@ class Resource(object):
because the stream API does not require the entire resource data to be read (and possibly decompressed) in advance. because the stream API does not require the entire resource data to be read (and possibly decompressed) in advance.
""" """
return io.BytesIO(self.data) if self.compressed_info is None:
return self.open_raw()
else:
f = self.open_raw()
f.seek(self.compressed_info.header_length)
return compress.DecompressingStream(f, self.compressed_info, close_stream=True)
class _LazyResourceMap(typing.Mapping[int, Resource]): class _LazyResourceMap(typing.Mapping[int, Resource]):

View File

@ -66,3 +66,88 @@ def decompress(data: bytes, *, debug: bool = False) -> bytes:
"""Decompress the given compressed resource data.""" """Decompress the given compressed resource data."""
return b"".join(decompress_stream(io.BytesIO(data), debug=debug)) return b"".join(decompress_stream(io.BytesIO(data), debug=debug))
class DecompressingStream(io.BufferedIOBase, typing.BinaryIO):
_compressed_stream: typing.BinaryIO
_close_stream: bool
_header_info: CompressedHeaderInfo
_decompress_iter: typing.Iterator[bytes]
_decompressed_stream: typing.BinaryIO
_seek_position: int
def __init__(self, compressed_stream: typing.BinaryIO, header_info: typing.Optional[CompressedHeaderInfo], *, close_stream: bool = False) -> None:
super().__init__()
self._compressed_stream = compressed_stream
self._close_stream = close_stream
if header_info is not None:
self._header_info = header_info
else:
self._header_info = CompressedHeaderInfo.parse_stream(self._compressed_stream)
self._decompress_iter = decompress_stream_parsed(self._header_info, self._compressed_stream)
self._decompressed_stream = io.BytesIO()
self._seek_position = 0
# This override does nothing,
# but is needed to make mypy happy,
# otherwise it complains (apparently incorrectly) about the __enter__ definitions from IOBase and BinaryIO being incompatible with each other.
def __enter__(self: "DecompressingStream") -> "DecompressingStream":
return super().__enter__()
def close(self) -> None:
super().close()
if self._close_stream:
self._compressed_stream.close()
del self._decompress_iter
self._decompressed_stream.close()
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._seek_position
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
if offset < 0:
raise ValueError(f"Negative seek offset not allowed with SEEK_SET: {offset}")
self._seek_position = offset
elif whence == io.SEEK_CUR:
self._seek_position += offset
elif whence == io.SEEK_END:
self._seek_position = self._header_info.decompressed_length - offset
else:
raise ValueError(f"Invalid whence value: {whence}")
self._seek_position = max(0, min(self._header_info.decompressed_length, self._seek_position))
return self._seek_position
def readable(self) -> bool:
return True
def read(self, size: typing.Optional[int] = -1) -> bytes:
if size is None:
size = -1
self._decompressed_stream.seek(0, io.SEEK_END)
if size < 0:
for chunk in self._decompress_iter:
self._decompressed_stream.write(chunk)
else:
if self._decompressed_stream.tell() - self._seek_position < size:
for chunk in self._decompress_iter:
self._decompressed_stream.write(chunk)
if self._decompressed_stream.tell() - self._seek_position >= size:
break
self._decompressed_stream.seek(self._seek_position)
ret = self._decompressed_stream.read(size)
self._seek_position += len(ret)
return ret