Make the generic decompression API stream-based

The non-stream-based APIs still exist as before and are not deprecated,
they just act as thin wrappers around the stream-based API.

The main rsrcfork module doesn't use the stream-based APIs yet, because
it reads each resource's data all at once and not incrementally.
This commit is contained in:
dgelessus 2019-10-02 16:28:40 +02:00
parent 6559cbc337
commit 8db1b22bdc
5 changed files with 37 additions and 27 deletions

View File

@ -1,3 +1,6 @@
import io
import typing
from . import dcmp0
from . import dcmp1
from . import dcmp2
@ -11,34 +14,46 @@ __all__ = [
# Maps 'dcmp' IDs to their corresponding Python implementations.
# Each decompressor has the signature (header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes.
# Each decompressor has the signature (header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes].
DECOMPRESSORS = {
0: dcmp0.decompress,
1: dcmp1.decompress,
2: dcmp2.decompress,
0: dcmp0.decompress_stream,
1: dcmp1.decompress_stream,
2: dcmp2.decompress_stream,
}
def decompress_parsed(header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
"""Decompress the given compressed resource data, whose header has already been removed and parsed into a CompressedHeaderInfo object."""
def decompress_stream_parsed(header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes]:
"""Decompress compressed resource data from a stream, whose header has already been read and parsed into a CompressedHeaderInfo object."""
try:
decompress_func = DECOMPRESSORS[header_info.dcmp_id]
except KeyError:
raise DecompressError(f"Unsupported 'dcmp' ID: {header_info.dcmp_id}")
decompressed = decompress_func(header_info, data, debug=debug)
if len(decompressed) != header_info.decompressed_length:
raise DecompressError(f"Actual length of decompressed data ({len(decompressed)}) does not match length stored in resource ({header_info.decompressed_length})")
return decompressed
def decompress(data: bytes, *, debug: bool=False) -> bytes:
"""Decompress the given compressed resource data."""
decompressed_length = 0
for chunk in decompress_func(header_info, stream, debug=debug):
decompressed_length += len(chunk)
yield chunk
header_info = CompressedHeaderInfo.parse(data)
if decompressed_length != header_info.decompressed_length:
raise DecompressError(f"Actual length of decompressed data ({decompressed_length}) does not match length stored in resource ({header_info.decompressed_length})")
def decompress_parsed(header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
"""Decompress the given compressed resource data, whose header has already been removed and parsed into a CompressedHeaderInfo object."""
return b"".join(decompress_stream_parsed(header_info, io.BytesIO(data), debug=debug))
def decompress_stream(stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes]:
"""Decompress compressed resource data from a stream."""
header_info = CompressedHeaderInfo.parse_stream(stream)
if debug:
print(f"Compressed resource data header: {header_info}")
return decompress_parsed(header_info, data[header_info.header_length:], debug=debug)
yield from decompress_stream_parsed(header_info, stream, debug=debug)
def decompress(data: bytes, *, debug: bool=False) -> bytes:
"""Decompress the given compressed resource data."""
return b"".join(decompress_stream(io.BytesIO(data), debug=debug))

View File

@ -37,9 +37,9 @@ STRUCT_COMPRESSED_SYSTEM_HEADER = struct.Struct(">h4s")
class CompressedHeaderInfo(object):
@classmethod
def parse(cls, data: bytes) -> "CompressedHeaderInfo":
def parse_stream(cls, stream: typing.BinaryIO) -> "CompressedHeaderInfo":
try:
signature, header_length, compression_type, decompressed_length, remainder = STRUCT_COMPRESSED_HEADER.unpack_from(data)
signature, header_length, compression_type, decompressed_length, remainder = STRUCT_COMPRESSED_HEADER.unpack(stream.read(STRUCT_COMPRESSED_HEADER.size))
except struct.error:
raise DecompressError(f"Invalid header")
if signature != COMPRESSED_SIGNATURE:
@ -61,6 +61,10 @@ class CompressedHeaderInfo(object):
else:
raise DecompressError(f"Unsupported compression type: 0x{compression_type:>04x}")
@classmethod
def parse(cls, data: bytes) -> "CompressedHeaderInfo":
return cls.parse_stream(io.BytesIO(data))
header_length: int
compression_type: int
decompressed_length: int

View File

@ -267,6 +267,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
if debug:
print(f"Decompressed {decompressed_length:#x} bytes so far")
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))

View File

@ -142,6 +142,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
if debug:
print(f"Decompressed {decompressed_length:#x} bytes so far")
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))

View File

@ -175,6 +175,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
decompress_func = _decompress_system_untagged
yield from decompress_func(common.make_peekable(stream), header_info.decompressed_length, table, debug=debug)
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))