mirror of
https://github.com/dgelessus/python-rsrcfork.git
synced 2025-02-09 02:31:09 +00:00
Refactor .dcmp2 to be stream-based
This is a little more complex than with the other decompressors, because .dcmp2 has to behave differently when at the byte before EOF. Checking whether this is the case requires lookahead, which is not easy to do with a plain IO stream. Some buffered IO streams provide a peek method for lookahead, but others don't (such as io.BytesIO). There is no standard way to wrap an already buffered IO stream to add a peek method, so we need a custom wrapper class and helper function for this purpose.
This commit is contained in:
parent
1e79dc3c50
commit
6559cbc337
@ -1,3 +1,4 @@
|
||||
import io
|
||||
import struct
|
||||
import typing
|
||||
|
||||
@ -100,6 +101,74 @@ class CompressedSystemHeaderInfo(CompressedHeaderInfo):
|
||||
return f"{type(self).__qualname__}(header_length={self.header_length}, compression_type=0x{self.compression_type:>04x}, decompressed_length={self.decompressed_length}, dcmp_id={self.dcmp_id}, parameters={self.parameters!r})"
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
class PeekableIO(typing.Protocol):
|
||||
"""Minimal protocol for binary IO streams that support the peek method.
|
||||
|
||||
The peek method is supported by various standard Python binary IO streams, such as io.BufferedReader. If a stream does not natively support the peek method, it may be wrapped using the custom helper function make_peekable.
|
||||
"""
|
||||
|
||||
def readable(self) -> bool: ...
|
||||
def read(self, size: typing.Optional[int] = ...) -> bytes: ...
|
||||
def peek(self, size: int = ...) -> bytes: ...
|
||||
|
||||
|
||||
class _PeekableIOWrapper(object):
|
||||
"""Wrapper class to add peek support to an existing stream. Do not instantiate this class directly, use the make_peekable function instead.
|
||||
|
||||
Python provides a standard io.BufferedReader class, which supports the peek method. However, according to its documentation, it only supports wrapping io.RawIOBase subclasses, and not streams which are already otherwise buffered.
|
||||
|
||||
Warning: this class does not perform any buffering of its own, outside of what is required to make peek work. It is strongly recommended to only wrap streams that are already buffered or otherwise fast to read from. In particular, raw streams (io.RawIOBase subclasses) should be wrapped using io.BufferedReader instead.
|
||||
"""
|
||||
|
||||
_wrapped: typing.BinaryIO
|
||||
_readahead: bytes
|
||||
|
||||
def __init__(self, wrapped: typing.BinaryIO) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._wrapped = wrapped
|
||||
self._readahead = b""
|
||||
|
||||
def readable(self) -> bool:
|
||||
return self._wrapped.readable()
|
||||
|
||||
def read(self, size: typing.Optional[int] = None) -> bytes:
|
||||
if size is None or size < 0:
|
||||
ret = self._readahead + self._wrapped.read()
|
||||
self._readahead = b""
|
||||
elif size <= len(self._readahead):
|
||||
ret = self._readahead[:size]
|
||||
self._readahead = self._readahead[size:]
|
||||
else:
|
||||
ret = self._readahead + self._wrapped.read(size - len(self._readahead))
|
||||
self._readahead = b""
|
||||
|
||||
return ret
|
||||
|
||||
def peek(self, size: int = -1) -> bytes:
|
||||
if not self._readahead:
|
||||
self._readahead = self._wrapped.read(io.DEFAULT_BUFFER_SIZE if size < 0 else size)
|
||||
return self._readahead
|
||||
|
||||
|
||||
def make_peekable(stream: typing.BinaryIO) -> "PeekableIO":
|
||||
"""Wrap an arbitrary binary IO stream so that it supports the peek method.
|
||||
|
||||
The stream is wrapped as efficiently as possible (or not at all if it already supports the peek method). However, in the worst case a custom wrapper class needs to be used, which may not be particularly efficient and only supports a very minimal interface. The only methods that are guaranteed to exist on the returned stream are readable, read, and peek.
|
||||
"""
|
||||
|
||||
if hasattr(stream, "peek"):
|
||||
# Stream is already peekable, nothing to be done.
|
||||
return typing.cast("PeekableIO", stream)
|
||||
elif isinstance(stream, io.RawIOBase):
|
||||
# Raw IO streams can be wrapped efficiently using BufferedReader.
|
||||
return io.BufferedReader(stream)
|
||||
else:
|
||||
# Other streams need to be wrapped using our custom wrapper class.
|
||||
return _PeekableIOWrapper(stream)
|
||||
|
||||
|
||||
def read_exact(stream: typing.BinaryIO, byte_count: int) -> bytes:
|
||||
"""Read byte_count bytes from the stream and raise an exception if too few bytes are read (i. e. if EOF was hit prematurely)."""
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import io
|
||||
import struct
|
||||
import typing
|
||||
|
||||
@ -73,65 +74,66 @@ def _split_bits(i: int) -> typing.Tuple[bool, bool, bool, bool, bool, bool, bool
|
||||
)
|
||||
|
||||
|
||||
def _decompress_system_untagged(data: bytes, decompressed_length: int, table: typing.Sequence[bytes], *, debug: bool=False) -> bytes:
|
||||
parts = []
|
||||
i = 0
|
||||
while i < len(data):
|
||||
if i == len(data) - 1 and decompressed_length % 2 != 0:
|
||||
def _decompress_system_untagged(stream: "common.PeekableIO", decompressed_length: int, table: typing.Sequence[bytes], *, debug: bool=False) -> typing.Iterator[bytes]:
|
||||
while True: # Loop is terminated when EOF is reached.
|
||||
table_index_data = stream.read(1)
|
||||
if not table_index_data:
|
||||
# End of compressed data.
|
||||
break
|
||||
elif not stream.peek(1) and decompressed_length % 2 != 0:
|
||||
# Special case: if we are at the last byte of the compressed data, and the decompressed data has an odd length, the last byte is a single literal byte, and not a table reference.
|
||||
if debug:
|
||||
print(f"Last byte: {data[-1:]}")
|
||||
parts.append(data[-1:])
|
||||
print(f"Last byte: {table_index_data}")
|
||||
yield table_index_data
|
||||
break
|
||||
|
||||
# Compressed data is untagged, every byte is a table reference.
|
||||
(table_index,) = table_index_data
|
||||
if debug:
|
||||
print(f"Reference: {data[i]} -> {table[data[i]]}")
|
||||
parts.append(table[data[i]])
|
||||
i += 1
|
||||
|
||||
return b"".join(parts)
|
||||
print(f"Reference: {table_index} -> {table[table_index]}")
|
||||
yield table[table_index]
|
||||
|
||||
def _decompress_system_tagged(data: bytes, decompressed_length: int, table: typing.Sequence[bytes], *, debug: bool=False) -> bytes:
|
||||
parts = []
|
||||
i = 0
|
||||
while i < len(data):
|
||||
if i == len(data) - 1 and decompressed_length % 2 != 0:
|
||||
def _decompress_system_tagged(stream: "common.PeekableIO", decompressed_length: int, table: typing.Sequence[bytes], *, debug: bool=False) -> typing.Iterator[bytes]:
|
||||
while True: # Loop is terminated when EOF is reached.
|
||||
tag_data = stream.read(1)
|
||||
if not tag_data:
|
||||
# End of compressed data.
|
||||
break
|
||||
elif not stream.peek(1) and decompressed_length % 2 != 0:
|
||||
# Special case: if we are at the last byte of the compressed data, and the decompressed data has an odd length, the last byte is a single literal byte, and not a tag or a table reference.
|
||||
if debug:
|
||||
print(f"Last byte: {data[-1:]}")
|
||||
parts.append(data[-1:])
|
||||
print(f"Last byte: {tag_data}")
|
||||
yield tag_data
|
||||
break
|
||||
|
||||
# Compressed data is tagged, each tag byte is followed by 8 table references and/or literals.
|
||||
tag = data[i]
|
||||
(tag,) = tag_data
|
||||
if debug:
|
||||
print(f"Tag: 0b{tag:>08b}")
|
||||
i += 1
|
||||
for is_ref in _split_bits(tag):
|
||||
if is_ref:
|
||||
# This is a table reference (a single byte that is an index into the table).
|
||||
table_index_data = stream.read(1)
|
||||
if not table_index_data:
|
||||
# End of compressed data.
|
||||
break
|
||||
(table_index,) = table_index_data
|
||||
if debug:
|
||||
print(f"Reference: {data[i]} -> {table[data[i]]}")
|
||||
parts.append(table[data[i]])
|
||||
i += 1
|
||||
print(f"Reference: {table_index} -> {table[table_index]}")
|
||||
yield table[table_index]
|
||||
else:
|
||||
# This is a literal (two uncompressed bytes that are literally copied into the output).
|
||||
# Note: if i == len(data)-1, the literal is actually only a single byte long.
|
||||
# This case is handled automatically - the slice extends one byte past the end of the data, and only one byte is returned.
|
||||
literal = stream.read(2)
|
||||
if not literal:
|
||||
# End of compressed data.
|
||||
break
|
||||
# Note: the literal may be only a single byte long if it is located exactly at EOF. This is intended and expected - the 1-byte literal is yielded normally, and on the next iteration, decompression is terminated as EOF is detected.
|
||||
if debug:
|
||||
print(f"Literal: {data[i:i+2]}")
|
||||
parts.append(data[i:i + 2])
|
||||
i += 2
|
||||
|
||||
# If the end of the compressed data is reached in the middle of a chunk, all further tag bits are ignored (they should be zero) and decompression ends.
|
||||
if i >= len(data):
|
||||
break
|
||||
|
||||
return b"".join(parts)
|
||||
print(f"Literal: {literal}")
|
||||
yield literal
|
||||
|
||||
|
||||
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
def decompress_stream(header_info: common.CompressedHeaderInfo, stream: typing.BinaryIO, *, debug: bool=False) -> typing.Iterator[bytes]:
|
||||
"""Decompress compressed data in the format used by 'dcmp' (2)."""
|
||||
|
||||
if not isinstance(header_info, common.CompressedSystemHeaderInfo):
|
||||
@ -155,18 +157,15 @@ def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug:
|
||||
print(f"Flags: {flags}")
|
||||
|
||||
if ParameterFlags.CUSTOM_TABLE in flags:
|
||||
table_start = 0
|
||||
data_start = table_start + table_count * 2
|
||||
table = []
|
||||
for i in range(table_start, data_start, 2):
|
||||
table.append(data[i:i + 2])
|
||||
for _ in range(table_count):
|
||||
table.append(common.read_exact(stream, 2))
|
||||
if debug:
|
||||
print(f"Using custom table: {table}")
|
||||
else:
|
||||
if table_count_m1 != 0:
|
||||
raise common.DecompressError(f"table_count_m1 field is {table_count_m1}, but must be zero when the default table is used")
|
||||
table = DEFAULT_TABLE
|
||||
data_start = 0
|
||||
if debug:
|
||||
print("Using default table")
|
||||
|
||||
@ -175,4 +174,7 @@ def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug:
|
||||
else:
|
||||
decompress_func = _decompress_system_untagged
|
||||
|
||||
return decompress_func(data[data_start:], header_info.decompressed_length, table, debug=debug)
|
||||
yield from decompress_func(common.make_peekable(stream), header_info.decompressed_length, table, debug=debug)
|
||||
|
||||
def decompress(header_info: common.CompressedHeaderInfo, data: bytes, *, debug: bool=False) -> bytes:
|
||||
return b"".join(decompress_stream(header_info, io.BytesIO(data), debug=debug))
|
||||
|
Loading…
x
Reference in New Issue
Block a user