Add class for resource filters in place of lambdas

This is easier to debug (printing out a lambda doesn't show what values
it checks against) and makes it easier to check that the filter values
are valid.
This commit is contained in:
dgelessus 2019-12-21 03:46:45 +01:00
parent 2b0bbb19ed
commit 2abf6e2a06

View File

@ -83,13 +83,23 @@ def bytes_escape(bs: bytes, *, quote: typing.Optional[str]=None) -> str:
return "".join(out)
def filter_to_predicate(filter: str) -> typing.Callable[[api.Resource], bool]:
MIN_RESOURCE_ID = -0x8000
MAX_RESOURCE_ID = 0x7fff
class ResourceFilter(object):
type: bytes
min_id: int
max_id: int
name: typing.Optional[bytes]
@classmethod
def from_string(cls, filter: str) -> "ResourceFilter":
if len(filter) == 4:
restype = filter.encode("ascii")
return lambda res: res.type == restype
return cls(restype, MIN_RESOURCE_ID, MAX_RESOURCE_ID, None)
elif filter[0] == filter[-1] == "'":
restype = bytes_unescape(filter[1:-1])
return lambda res: res.type == restype
return cls(restype, MIN_RESOURCE_ID, MAX_RESOURCE_ID, None)
else:
pos = filter.find("'", 1)
if pos == -1:
@ -104,26 +114,45 @@ def filter_to_predicate(filter: str) -> typing.Callable[[api.Resource], bool]:
f"Invalid filter {filter!r}: Resource type is not a single-quoted type identifier: {restype_str!r}")
restype = bytes_unescape(restype_str[1:-1])
if len(restype) != 4:
raise ValueError(
f"Invalid filter {filter!r}: Type identifier must be 4 bytes after replacing escapes, got {len(restype)} bytes: {restype!r}")
if resid_str[0] != "(" or resid_str[-1] != ")":
raise ValueError(f"Invalid filter {filter!r}: Resource ID must be parenthesized")
resid_str = resid_str[1:-1]
if resid_str[0] == resid_str[-1] == '"':
name = bytes_unescape(resid_str[1:-1])
return lambda res: res.type == restype and res.name == name
return cls(restype, MIN_RESOURCE_ID, MAX_RESOURCE_ID, name)
elif ":" in resid_str:
if resid_str.count(":") > 1:
raise ValueError(f"Invalid filter {filter!r}: Too many colons in ID range expression: {resid_str!r}")
start_str, end_str = resid_str.split(":")
start, end = int(start_str), int(end_str)
return lambda res: res.type == restype and start <= res.id <= end
return cls(restype, start, end, None)
else:
resid = int(resid_str)
return lambda res: res.type == restype and res.id == resid
return cls(restype, resid, resid, None)
def __init__(self, restype: bytes, min_id: int, max_id: int, name: typing.Optional[bytes]) -> None:
super().__init__()
if len(restype) != 4:
raise ValueError(f"Invalid filter: Type code must be exactly 4 bytes long, not {len(restype)} bytes: {restype!r}")
elif min_id < MIN_RESOURCE_ID:
raise ValueError(f"Invalid filter: Resource ID lower bound ({min_id}) cannot be lower than {MIN_RESOURCE_ID}")
elif max_id > MAX_RESOURCE_ID:
raise ValueError(f"Invalid filter: Resource ID upper bound ({max_id}) cannot be greater than {MAX_RESOURCE_ID}")
elif min_id > max_id:
raise ValueError(f"Invalid filter: Resource ID lower bound ({min_id}) cannot be greater than upper bound ({max_id})")
self.type = restype
self.min_id = min_id
self.max_id = max_id
self.name = name
def __repr__(self) -> str:
return f"{type(self).__name__}({self.type!r}, {self.min_id!r}, {self.max_id!r}, {self.name!r})"
def matches(self, res: api.Resource) -> bool:
return res.type == self.type and self.min_id <= res.id <= self.max_id and (self.name is None or res.name == self.name)
def filter_resources(rf: api.ResourceFile, filters: typing.Sequence[str]) -> typing.Iterable[api.Resource]:
if not filters:
@ -131,11 +160,11 @@ def filter_resources(rf: api.ResourceFile, filters: typing.Sequence[str]) -> typ
for reses in rf.values():
yield from reses.values()
else:
preds = [filter_to_predicate(filter) for filter in filters]
filter_objs = [ResourceFilter.from_string(filter) for filter in filters]
for reses in rf.values():
for res in reses.values():
if any(pred(res) for pred in preds):
if any(filter_obj.matches(res) for filter_obj in filter_objs):
yield res
def hexdump(data: bytes) -> None: