154 lines
5.0 KiB
Python
154 lines
5.0 KiB
Python
import io
|
|
import typing
|
|
|
|
from . import dcmp0
|
|
from . import dcmp1
|
|
from . import dcmp2
|
|
|
|
from .common import DecompressError, CompressedHeaderInfo, CompressedType8HeaderInfo, CompressedType9HeaderInfo
|
|
|
|
# Public API of this package, in ASCII sort order.
# DecompressingStream is defined in this module and has a public (non-underscore) name,
# so it is exported alongside the function-based decompression API.
__all__ = [
    "CompressedHeaderInfo",
    "CompressedType8HeaderInfo",
    "CompressedType9HeaderInfo",
    "DecompressError",
    "DecompressingStream",
    "decompress",
    "decompress_parsed",
    "decompress_stream",
    "decompress_stream_parsed",
]
|
|
|
|
|
|
# Lookup table from 'dcmp' ID to the Python function that implements that decompressor.
# Every entry is called as (header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False)
# and returns a typing.Iterator[bytes].
DECOMPRESSORS = {
    dcmp_id: dcmp_module.decompress_stream
    for dcmp_id, dcmp_module in ((0, dcmp0), (1, dcmp1), (2, dcmp2))
}
|
|
|
|
|
|
def decompress_stream_parsed(header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool = False) -> typing.Iterator[bytes]:
    """Decompress compressed resource data from a stream, whose header has already been read and parsed into a CompressedHeaderInfo object.

    This is a generator: nothing is read or validated until the returned iterator is consumed,
    and the errors described below are raised during iteration, not when this function is called.

    :param header_info: The already parsed header. Its ``dcmp_id`` selects the decompressor
        and its ``decompressed_length`` is used to validate the output.
    :param stream: The stream from which the decompressor reads the compressed data.
    :param debug: Passed through unchanged to the selected decompressor.
    :return: An iterator over the chunks of decompressed data produced by the decompressor.
    :raises DecompressError: If ``header_info.dcmp_id`` has no entry in ``DECOMPRESSORS``,
        or, once the decompressor is exhausted, if the total length of all yielded chunks
        differs from ``header_info.decompressed_length``.
    """

    try:
        decompress_func = DECOMPRESSORS[header_info.dcmp_id]
    except KeyError:
        # The KeyError is only an artifact of the table lookup and carries no extra information,
        # so it is deliberately not chained into the traceback.
        raise DecompressError(f"Unsupported 'dcmp' ID: {header_info.dcmp_id}") from None

    decompressed_length = 0
    for chunk in decompress_func(header_info, stream, debug=debug):
        decompressed_length += len(chunk)
        yield chunk

    # The length can only be checked after all chunks have been passed on to the caller.
    if decompressed_length != header_info.decompressed_length:
        raise DecompressError(f"Actual length of decompressed data ({decompressed_length}) does not match length stored in resource ({header_info.decompressed_length})")
|
|
|
|
|
|
def decompress_parsed(header_info: CompressedHeaderInfo, data: bytes, *, debug: bool = False) -> bytes:
    """Decompress header-less compressed resource data held in memory.

    The header must already have been stripped from ``data`` and parsed into ``header_info``.
    The data is wrapped in an in-memory stream, run through decompress_stream_parsed,
    and all resulting chunks are concatenated into a single bytes object.
    """

    compressed_stream = io.BytesIO(data)
    chunks = decompress_stream_parsed(header_info, compressed_stream, debug=debug)
    return b"".join(chunks)
|
|
|
|
|
|
def decompress_stream(stream: typing.BinaryIO, *, debug: bool = False) -> typing.Iterator[bytes]:
    """Decompress compressed resource data (header included) from a stream.

    The header is parsed from the stream via CompressedHeaderInfo.parse_stream,
    then the same stream is handed to decompress_stream_parsed.
    With ``debug`` set, the parsed header is printed before decompression starts.
    As this is a generator, no work is done until the result is iterated.
    """

    parsed_header = CompressedHeaderInfo.parse_stream(stream)

    if debug:
        print(f"Compressed resource data header: {parsed_header}")

    yield from decompress_stream_parsed(parsed_header, stream, debug=debug)
|
|
|
|
|
|
def decompress(data: bytes, *, debug: bool = False) -> bytes:
    """Decompress compressed resource data (header included) held in memory.

    The data is wrapped in an in-memory stream, passed to decompress_stream,
    and the resulting chunks are concatenated into a single bytes object.
    """

    compressed_stream = io.BytesIO(data)
    chunks = decompress_stream(compressed_stream, debug=debug)
    return b"".join(chunks)
|
|
|
|
|
|
class DecompressingStream(io.BufferedIOBase, typing.BinaryIO):
    """Readable, seekable binary stream that decompresses compressed resource data on demand.

    Chunks are pulled from the decompressor only as far as a read requires.
    Everything decompressed so far is kept in an in-memory cache,
    so seeking backwards and re-reading does not decompress anything again.
    """

    # The stream from which the compressed data is read.
    _compressed_stream: typing.BinaryIO
    # Whether close() should also close _compressed_stream.
    _close_stream: bool
    # Parsed header of the compressed data; its decompressed_length bounds all seek positions.
    _header_info: "CompressedHeaderInfo"
    # Iterator that yields the decompressed chunks; deleted by close().
    _decompress_iter: typing.Iterator[bytes]
    # In-memory cache of all data decompressed so far.
    _decompressed_stream: typing.BinaryIO
    # Current read position, as an offset into the decompressed data.
    _seek_position: int

    def __init__(self, compressed_stream: typing.BinaryIO, header_info: typing.Optional["CompressedHeaderInfo"], *, close_stream: bool = False) -> None:
        """Create a decompressing stream on top of a stream of compressed data.

        :param compressed_stream: The stream from which the compressed data is read.
        :param header_info: The already parsed header,
            or None to parse the header from ``compressed_stream`` first.
        :param close_stream: Whether closing this stream also closes ``compressed_stream``.
        """

        super().__init__()

        self._compressed_stream = compressed_stream
        self._close_stream = close_stream

        if header_info is not None:
            self._header_info = header_info
        else:
            self._header_info = CompressedHeaderInfo.parse_stream(self._compressed_stream)

        self._decompress_iter = decompress_stream_parsed(self._header_info, self._compressed_stream)
        self._decompressed_stream = io.BytesIO()
        self._seek_position = 0

    # This override does nothing,
    # but is needed to make mypy happy,
    # otherwise it complains (apparently incorrectly) about the __enter__ definitions from IOBase and BinaryIO being incompatible with each other.
    def __enter__(self: "DecompressingStream") -> "DecompressingStream":
        return super().__enter__()

    def close(self) -> None:
        """Close this stream and release the decompressor and the cache.

        The compressed stream is closed as well if ``close_stream`` was set on construction.
        Calling this method more than once is allowed and has no further effect.
        """

        already_closed = self.closed
        super().close()
        if already_closed:
            # Everything below was already released by the first call;
            # in particular, _decompress_iter no longer exists.
            return

        if self._close_stream:
            self._compressed_stream.close()
        del self._decompress_iter
        self._decompressed_stream.close()

    def seekable(self) -> bool:
        """Return True: seeking is always supported."""

        return True

    def tell(self) -> int:
        """Return the current position in the decompressed data."""

        return self._seek_position

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Change the current position in the decompressed data.

        :param offset: The offset, relative to the position selected by ``whence``.
        :param whence: io.SEEK_SET (start of data, ``offset`` must not be negative),
            io.SEEK_CUR (current position),
            or io.SEEK_END (end of data, ``offset`` is normally zero or negative).
        :return: The new absolute position,
            clamped to the range from 0 to the decompressed length from the header.
        :raises ValueError: If ``whence`` is invalid, or if ``offset`` is negative with io.SEEK_SET.
        """

        length = self._header_info.decompressed_length

        if whence == io.SEEK_SET:
            if offset < 0:
                raise ValueError(f"Negative seek offset not allowed with SEEK_SET: {offset}")

            new_position = offset
        elif whence == io.SEEK_CUR:
            new_position = self._seek_position + offset
        elif whence == io.SEEK_END:
            # As with every other stream, the offset is added to the end position,
            # so a negative offset addresses data before the end.
            new_position = length + offset
        else:
            raise ValueError(f"Invalid whence value: {whence}")

        self._seek_position = max(0, min(length, new_position))

        return self._seek_position

    def readable(self) -> bool:
        """Return True: reading is always supported."""

        return True

    def read(self, size: typing.Optional[int] = -1) -> bytes:
        """Read decompressed data starting at the current position.

        :param size: The maximum number of bytes to read.
            None or a negative value reads everything up to the end of the data.
        :return: The data read, which is shorter than ``size`` if the decompressor runs out of data.
        """

        if size is None:
            size = -1

        cache = self._decompressed_stream
        # Position the cache at its end, so that new chunks are appended
        # and tell() reports how much data has been decompressed so far.
        cache.seek(0, io.SEEK_END)

        if size < 0:
            # Read to end: decompress everything that is left.
            for chunk in self._decompress_iter:
                cache.write(chunk)
        elif cache.tell() - self._seek_position < size:
            # Decompress only until the cache covers the requested range.
            for chunk in self._decompress_iter:
                cache.write(chunk)

                if cache.tell() - self._seek_position >= size:
                    break

        cache.seek(self._seek_position)
        ret = cache.read(size)
        self._seek_position += len(ret)
        return ret
|