mirror of
https://github.com/dgelessus/python-rsrcfork.git
synced 2024-12-29 04:29:24 +00:00
Make the generic decompression API stream-based
The non-stream-based APIs still exist as before and are not deprecated; they just act as thin wrappers around the stream-based API. The main rsrcfork module doesn't use the stream-based APIs yet, because it reads each resource's data all at once rather than incrementally.
This commit is contained in:
parent
6559cbc337
commit
8db1b22bdc
@ -1,3 +1,6 @@
|
||||
import io
|
||||
import typing
|
||||
|
||||
from . import dcmp0
|
||||
from . import dcmp1
|
||||
from . import dcmp2
|
||||
@ -11,34 +14,46 @@ __all__ = [
|
||||
|
||||
|
||||
# Maps 'dcmp' IDs to their corresponding Python implementations.
|
||||
# Each decompressor has the signature (header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes.
|
||||
# Each decompressor has the signature (header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes].
|
||||
DECOMPRESSORS = {
|
||||
0: dcmp0.decompress,
|
||||
1: dcmp1.decompress,
|
||||
2: dcmp2.decompress,
|
||||
0: dcmp0.decompress_stream,
|
||||
1: dcmp1.decompress_stream,
|
||||
2: dcmp2.decompress_stream,
|
||||
}
|
||||
|
||||
|
||||
def decompress_parsed(header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
"""Decompress the given compressed resource data, whose header has already been removed and parsed into a CompressedHeaderInfo object."""
|
||||
def decompress_stream_parsed(header_info: CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes]:
|
||||
"""Decompress compressed resource data from a stream, whose header has already been read and parsed into a CompressedHeaderInfo object."""
|
||||
|
||||
try:
|
||||
decompress_func = DECOMPRESSORS[header_info.dcmp_id]
|
||||
except KeyError:
|
||||
raise DecompressError(f"Unsupported 'dcmp' ID: {header_info.dcmp_id}")
|
||||
|
||||
decompressed = decompress_func(header_info, data, debug=debug)
|
||||
if len(decompressed) != header_info.decompressed_length:
|
||||
raise DecompressError(f"Actual length of decompressed data ({len(decompressed)}) does not match length stored in resource ({header_info.decompressed_length})")
|
||||
return decompressed
|
||||
|
||||
|
||||
def decompress(data: bytes, *, debug: bool=False) -> bytes:
|
||||
"""Decompress the given compressed resource data."""
|
||||
decompressed_length = 0
|
||||
for chunk in decompress_func(header_info, stream, debug=debug):
|
||||
decompressed_length += len(chunk)
|
||||
yield chunk
|
||||
|
||||
header_info = CompressedHeaderInfo.parse(data)
|
||||
if decompressed_length != header_info.decompressed_length:
|
||||
raise DecompressError(f"Actual length of decompressed data ({decompressed_length}) does not match length stored in resource ({header_info.decompressed_length})")
|
||||
|
||||
def decompress_parsed(header_info: CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
"""Decompress the given compressed resource data, whose header has already been removed and parsed into a CompressedHeaderInfo object."""
|
||||
|
||||
return b"".join(decompress_stream_parsed(header_info, io.BytesIO(data), debug=debug))
|
||||
|
||||
def decompress_stream(stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes]:
|
||||
"""Decompress compressed resource data from a stream."""
|
||||
|
||||
header_info = CompressedHeaderInfo.parse_stream(stream)
|
||||
|
||||
if debug:
|
||||
print(f"Compressed resource data header: {header_info}")
|
||||
|
||||
return decompress_parsed(header_info, data[header_info.header_length:], debug=debug)
|
||||
yield from decompress_stream_parsed(header_info, stream, debug=debug)
|
||||
|
||||
def decompress(data: bytes, *, debug: bool=False) -> bytes:
|
||||
"""Decompress the given compressed resource data."""
|
||||
|
||||
return b"".join(decompress_stream(io.BytesIO(data), debug=debug))
|
||||
|
@ -37,9 +37,9 @@ STRUCT_COMPRESSED_SYSTEM_HEADER = struct.Struct(">h4s")
|
||||
|
||||
class CompressedHeaderInfo(object):
|
||||
@classmethod
|
||||
def parse(cls, data: bytes) -> "CompressedHeaderInfo":
|
||||
def parse_stream(cls, stream: typing.BinaryIO) -> "CompressedHeaderInfo":
|
||||
try:
|
||||
signature, header_length, compression_type, decompressed_length, remainder = STRUCT_COMPRESSED_HEADER.unpack_from(data)
|
||||
signature, header_length, compression_type, decompressed_length, remainder = STRUCT_COMPRESSED_HEADER.unpack(stream.read(STRUCT_COMPRESSED_HEADER.size))
|
||||
except struct.error:
|
||||
raise DecompressError(f"Invalid header")
|
||||
if signature != COMPRESSED_SIGNATURE:
|
||||
@ -61,6 +61,10 @@ class CompressedHeaderInfo(object):
|
||||
else:
|
||||
raise DecompressError(f"Unsupported compression type: 0x{compression_type:>04x}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data: bytes) -> "CompressedHeaderInfo":
|
||||
return cls.parse_stream(io.BytesIO(data))
|
||||
|
||||
header_length: int
|
||||
compression_type: int
|
||||
decompressed_length: int
|
||||
|
@ -267,6 +267,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
|
||||
|
||||
if debug:
|
||||
print(f"Decompressed {decompressed_length:#x} bytes so far")
|
||||
|
||||
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))
|
||||
|
@ -142,6 +142,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
|
||||
|
||||
if debug:
|
||||
print(f"Decompressed {decompressed_length:#x} bytes so far")
|
||||
|
||||
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))
|
||||
|
@ -175,6 +175,3 @@ def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.B
|
||||
decompress_func = _decompress_system_untagged
|
||||
|
||||
yield from decompress_func(common.make_peekable(stream), header_info.decompressed_length, table, debug=debug)
|
||||
|
||||
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))
|
||||
|
Loading…
Reference in New Issue
Block a user