lpch/patch_rip.py

567 lines
18 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import math
import struct
import re
from collections import defaultdict
# Names of the patch conditions.  A name's position in this list is its bit
# number in the 3-byte condition masks of the patch table, and the index that
# the BIND$ records of LinkedPatches.lib are looked up with.  The first six
# names are the same as the default value of the -roms option.
COND_NAMES = [
    'Plus',
    'SE',
    'II',
    'Portable',
    'IIci',
    'SuperMario',
    'noPatchProtector',
    'notVM',
    'notAUX',
    'hasHMMU',
    'hasPMMU',
    'hasMemoryDispatch',
    'has800KDriver',
    'hasFDHDDriver',
    'hasIWM',
    'hasEricksonOverpatchMistake',
    'hasEricksonSoundMgr',
    'notEricksonSoundMgr',
    'using24BitHeaps',
    'using32BitHeaps',
    'notTERROR',
    'hasTERROR',
    'hasC96',
    'hasPwrMgr',
]
# Maps a jump-table entry number to its symbol name; filled in from the file
# given with -sh.
global_sym_names = {}


def name(jt_offset):
    """Return a printable label for a jump-table entry.

    The label is the entry number as at least three upper-case hex digits,
    followed by a space and the symbol name when ``global_sym_names`` has a
    non-empty name for that entry.
    """
    label = '%03X' % jt_offset
    symbol = global_sym_names.get(jt_offset)
    if not symbol:
        return label
    return ' '.join((label, symbol))
def parse_res_file(f):
    """Read every 'lpch' resource out of a Rez-format dump.

    f is the path of the dump.  The file is read as bytes and handed to
    ``macresources.parse_rez_code``; only resources whose type is ``b'lpch'``
    are kept.

    Returns a list of ``(resource id, resource data)`` tuples sorted in
    ascending order.
    """
    # Imported here rather than at file level, so the module is only needed
    # when this function is actually called.
    import macresources

    # Read through a context manager so the handle is closed deterministically.
    with open(f, 'rb') as rez_file:
        rez_text = rez_file.read()

    return sorted(
        (res.id, res.data)
        for res in macresources.parse_rez_code(rez_text)
        if res.type == b'lpch'
    )
def parse_raw_files(ff):
    """Read lpch resources that were dumped as one raw file each.

    ff is an iterable of paths.  The decimal digits at the very end of each
    path are taken as the resource ID, and the file's bytes as its data.

    Returns a list of ``(resource id, resource data)`` tuples sorted in
    ascending order.

    Raises ValueError when a path does not end in at least one digit.
    """
    biglist = []
    for f in ff:
        # Everything that rstrip() would remove is the trailing run of digits.
        digits = f[len(f.rstrip('0123456789')):]
        if not digits:
            raise ValueError('file name does not end in a resource number: %r' % (f,))

        # Context manager so the handle does not stay open until collection.
        with open(f, 'rb') as raw_file:
            biglist.append((int(digits), raw_file.read()))

    return sorted(biglist)
def exact_log(n):
    """Return the base-2 logarithm of n if n is an exact power of two.

    Returns None for zero, for negative numbers and for any value that has
    more than one bit set.  The previous implementation added 2 per shift and
    so returned twice the logarithm.
    """
    if not n or n < 0:
        return None
    # A power of two has a single set bit, which n - 1 clears.
    if n & (n - 1):
        return None
    return n.bit_length() - 1
def count_bits(n):
    """Return how many 1 bits the integer n has.

    n is expected to be non-negative.
    """
    total = 0
    remaining = n
    while remaining:
        total += remaining & 1   # add the lowest bit, then drop it
        remaining >>= 1
    return total
class Mod:
    """A module found in the code of one lpch resource.

    Attributes (all filled in by the parsing loop after construction):
      entry_points    -- Ent objects that lie inside this module
      references      -- Ref objects that lie inside this module
      rom_references  -- RomRef objects that lie inside this module
      offset, stop    -- start and end offsets of the module in the code
      jt_entry        -- jump-table entry number, -1 until assigned
      rsrc_id         -- ID of the resource the module came from
    """

    def __init__(self):
        self.entry_points, self.references, self.rom_references = [], [], []
        self.offset = self.stop = self.jt_entry = -1
        self.rsrc_id = 0

    def __str__(self):
        return 'PROC %s' % name(self.jt_entry)
class Ent:
    """An additional entry point inside a module.

    offset is its position in the resource's code and jt_entry its jump-table
    entry number; both are -1 until the parsing loop fills them in.
    """

    def __init__(self):
        self.offset = self.jt_entry = -1

    def __str__(self):
        return 'ENTRY %s' % name(self.jt_entry)
class Ref:
    """A reference from patch code to a jump-table entry.

    Attributes (filled in by the parsing loop after construction):
      offset          -- position of the reference in the resource's code
      opcode          -- 4-bit code selecting one of the templates below
      jt_entry        -- jump-table entry number being referred to
      force_resident  -- whether the 'Resident' form of the mnemonic is used
    """

    # Assembly templates indexed by opcode.  'Y' marks where the optional
    # 'Resident' suffix goes and 'x' where the name of the target goes.
    _TEMPLATES = (
        'leaY x,A0',
        'leaY x,A1',
        'leaY x,A2',
        'leaY x,A3',
        'leaY x,A4',
        'leaY x,A5',
        'leaY x,A6',
        'leaY x,A7',
        'peaY x',
        'jsrY x',
        'jmpY x',
        'unknown11Y x',
        'unknown12Y x',
        'unknown13Y x',
        'unknown14Y x',
        'dcImportY x',
    )

    def __init__(self):
        self.offset = -1
        self.opcode = -1
        self.jt_entry = -1
        self.force_resident = False

    @property
    def assembly(self):
        """The reference written as a line of assembly, e.g. 'jsrResident 012 Foo'."""
        template = self._TEMPLATES[self.opcode]
        suffix = 'Resident' if self.force_resident else ''
        # Expand the 'Y' placeholder BEFORE inserting the target name.  Doing
        # it afterwards (as this code used to) also rewrote every 'Y' that
        # happened to occur inside the symbol name itself.
        return template.replace('Y', suffix).replace('x', name(self.jt_entry))

    def __str__(self):
        return self.assembly
# Maps the spelled-out form of a ROM reference, '(rom,$offset),(rom,$offset)...',
# to a symbolic name; filled in from the file given with -rh.
global_romref_names = {}


class RomRef:
    """A reference from patch code to an address in ROM.

    offset is the position of the reference in the resource's code and
    romofs_pairs a list of (ROM name, offset) pairs.
    """

    def __init__(self):
        self.offset = -1
        self.romofs_pairs = []

    def __str__(self):
        pieces = ['(%s,$%x)' % (rom, ofs) for rom, ofs in self.romofs_pairs]
        spelled_out = ','.join(pieces)
        symbolic = global_romref_names.get(spelled_out)
        if not symbolic:
            # No known name: fall back to the raw list of pairs.
            return spelled_out
        return 'ROM ' + symbolic
# Command-line interface.  Arguments are registered in the same order as
# before, so the generated help text is unchanged.
parser = argparse.ArgumentParser(description='''
Very hacky.
''')

parser.add_argument('src', nargs='+', help='Source file (.rdump) or files (numbered)')
parser.add_argument('-roms', nargs='+', default=['Plus', 'SE', 'II', 'Portable', 'IIci', 'SuperMario'])

# Boolean switches selecting what gets printed
for _flag, _help in (
    ('-pt', 'Print raw module tokens'),
    ('-pm', 'Print information about modules and code references'),
    ('-pr', 'Print information about ROM references'),
    ('-pj', 'Print jump table'),
    ('-pjh', 'Print jump table (hex)'),
    ('-pp', 'Print patch names'),
):
    parser.add_argument(_flag, action='store_true', help=_help)

# Options that take a single path
for _flag, _help in (
    ('-rh', 'LinkedPatches.lib, so we know how to name ROM references'),
    ('-sh', 'output of LinkedPatch -l, so we know how to name symbols'),
    ('-oo', 'Base destination path to dump resources as raw files'),
    ('-oc', 'Base destination path to dump code files'),
    ('-oe', 'Base destination path to dump code files with refs changed to NOPs'),
):
    parser.add_argument(_flag, help=_help)

parser.add_argument('-w', dest='width', type=int, default=128, help='Width in chars')

args = parser.parse_args()
# Load the lpch resources as a list of (resource ID, data) pairs.  A single
# source argument is parsed as a Rez dump, several as numbered raw files.
# Both parsers return the list sorted by ascending ID.
if len(args.src) == 1:
    lpch_list = parse_res_file(args.src[0])
else:
    lpch_list = parse_raw_files(args.src)

# Everything below indexes into lpch_list, so fail with a clear message
# instead of an IndexError when nothing was found.
if not lpch_list:
    parser.error('no lpch resources found in the given source')

# Check that we have the right number of declared ROMs...
# The largest ID is last in the sorted list.  bit_length() gives the floor of
# log2 exactly, where the float math.log(x, 2.0) could round just below an
# integer and then be truncated to one less.
roms_now = len(args.roms)
roms_at_build = (lpch_list[-1][0] + 1).bit_length() - 1
if roms_now != roms_at_build:
    print('Warning: %d ROMs specified but there were %d at build time' % (roms_now, roms_at_build))

# Sort the ROMs so that the most inclusive ones come first
# (most bits set in the ID first; ties broken by ascending ID)
lpch_list.sort(key=lambda rsrc: (-count_bits(rsrc[0]), rsrc[0]))
########################################################################
# LinkedPatches.lib, so we know how to name ROM references
# mutates global_romref_names, which the RomRef class can read
if args.rh:
    # Read via a context manager so the library file is closed right away.
    with open(args.rh, 'rb') as library_file:
        library = library_file.read()

    # Collect every BIND$<name>$<number>$<number>$ record, grouped by name.
    # Both numbers are written in decimal.
    name_map = defaultdict(list)
    for m in re.finditer(rb'BIND\$([A-Za-z0-9@%]+)\$(\d+)\$(\d+)\$', library):
        name_map[m.group(1).decode('ascii')].append((int(m.group(2)), int(m.group(3))))

    for rname, rlist in name_map.items():
        rlist.sort()
        # Build the key in the same '(name,$hex)' spelling that RomRef and the
        # -pr output use; the first number of each pair indexes COND_NAMES.
        keystring = ','.join('(%s,$%x)' % (COND_NAMES[k], v) for k, v in rlist)
        global_romref_names[keystring] = rname
# output of LinkedPatch -l, so we know how to name symbols
# mutates global_sym_names, which name() reads
if args.sh:
    # Context manager so the listing is closed once it has been read.
    with open(args.sh) as listing:
        for listing_line in listing:
            fields = listing_line.split()
            # Only lines of exactly two whitespace-separated fields are used:
            # a hexadecimal entry number followed by the symbol name.
            if len(fields) == 2:
                sym_number = int(fields[0], 16)
                sym_name = fields[1]
                global_sym_names[sym_number] = sym_name
########################################################################
# State accumulated over all resources by the loop below
large_rom_table = []    # ROM table rows of every resource, indexed by table slot
large_jump_table = []   # never filled in by this loop
all_modules = []        # Mod objects of every resource
code_list = [] # this is getting hackier and hackier

# Parse each lpch resource in turn.  idx is the read position inside data;
# the parts are read strictly in the order they appear below.
for num, data in lpch_list:
    if args.oo:
        with open(args.oo + str(num), 'wb') as out:
            out.write(data)

    # An ID that is an exact power of two has a single ROM bit set
    is_single = exact_log(num) is not None

    # The resource that was sorted to the front of lpch_list
    is_all = (num == lpch_list[0][0])

    # Bit i of the ID selects the i-th name given with -roms
    matches_roms = []
    for i, r in enumerate(args.roms):
        if (num >> i) & 1:
            matches_roms.append(r)

    idx = 0

    if is_single:
        # A leading word that must equal the number of resources whose ID
        # shares a bit with this one
        num_lpch_for_this_rom, = struct.unpack_from('>H', data, offset=idx)
        idx += 2
        counted = len([xnum for (xnum, xdata, *_) in lpch_list if xnum & num])
        assert num_lpch_for_this_rom == counted

    if is_all:
        # Two words that are read but not used afterwards
        bound_rom_addr_table_cnt, jump_table_cnt = struct.unpack_from('>HH', data, offset=idx)
        idx += 4

    code_size, = struct.unpack_from('>I', data, offset=idx)
    idx += 4

    code = data[idx:idx+code_size]
    idx += code_size
    code_list.append(code)

    if args.pt or args.pm or args.pr:
        if not is_all:
            print()
        print('lpch %d\t\t%db(%db)\t\t%s' % (num, len(data), code_size, ','.join(matches_roms)))

    if args.oc:
        with open(args.oc + str(num), 'wb') as out:
            out.write(code)

    # do the rom table
    rom_table_start, = struct.unpack_from('>H', data, offset=idx)
    idx += 2
    if rom_table_start == 0xFFFF:
        rom_table_start = None   # 0xFFFF means this resource has no ROM table

    rom_table = []
    if rom_table_start is not None:
        while 1:
            # One row: a 3-byte big-endian value per matching ROM.  The low
            # 23 bits are the offset; the top bit of the last value read
            # marks the final row.
            romofs_pairs = []
            human_readable_idx = idx   # not read anywhere in this loop
            for r in reversed(matches_roms): # data packed from newest to oldest rom
                the_int = int.from_bytes(data[idx:idx+3], byteorder='big')
                idx += 3
                romofs_pairs.append((r, the_int & 0x7FFFFF))
            romofs_pairs.reverse()
            rom_table.append(romofs_pairs)

            if args.pr:
                rom_lookup_name = ','.join('(%s,$%x)' % (k, v) for k, v in romofs_pairs)
                print(global_romref_names.get(rom_lookup_name, rom_lookup_name))

            if the_int & 0x800000:
                break

        # Grow the shared table as needed, then drop the rows into place
        while len(large_rom_table) < rom_table_start + len(rom_table):
            large_rom_table.append(None)
        large_rom_table[rom_table_start:rom_table_start+len(rom_table)] = rom_table

        if args.pr:
            print('ROM table entries are %d:%d' % (rom_table_start, rom_table_start+len(rom_table)))

    # Figure out where all the ROM references are
    # The list of 3-byte values is terminated by a zero entry, so read until
    # that terminator.  (A fixed cap of 10 iterations used to be applied here;
    # with more entries the terminator was never consumed and everything read
    # afterwards was misaligned.)  Reading past the end of data yields zero,
    # so this loop always ends.
    rom_exception_table = []
    while True:
        the_int = int.from_bytes(data[idx:idx+3], byteorder='big')
        idx += 3
        if the_int == 0:
            break
        rom_exception_table.append(the_int)

    rom_references = [] # this is what we can salvage from the foregoing overcooked code
    for code_offset in rom_exception_table:
        # Each table entry is the code offset of the first node of a chain.
        # A node is two big-endian words: the link to the next node and the
        # index of a row in large_rom_table.  A link of zero ends the chain.
        while 1:
            link, which_rom_part = struct.unpack_from('>HH', code, offset=code_offset)
            rom_references.append(RomRef())
            rom_references[-1].offset = code_offset
            rom_references[-1].romofs_pairs = large_rom_table[which_rom_part]
            if link == 0:
                break
            code_offset += 4 + 2 * link

    tokens = []

    # do the exception table
    # Decode the byte stream into (kind, argument) tokens up to the 'end' token
    while 1:
        opcode = data[idx]
        idx += 1

        this_is_an_entry = False   # not read anywhere in this loop

        if opcode <= 251:
            tok = ('distance', opcode * 2)

        elif opcode == 252: # skip entries in the jump table
            opcode2 = data[idx]
            idx += 1
            if opcode2 == 0: # end of packed jump table
                tok = ('end', None)
            elif 1 <= opcode2 <= 254: # number of jump table entries to skip
                tok = ('skipjt', opcode2)
            elif opcode2 == 255: # word follows with number of jump table entries to skip
                opcode3, = struct.unpack_from('>H', data, offset=idx)
                idx += 2
                tok = ('skipjt', opcode3)

        elif opcode == 253: # previous was reference list head for this module
            tok = ('prev=ref_list_head', None)

        elif opcode == 254: # previous was an entry, not a new module
            tok = ('prev=entry_not_module', None)

        elif opcode == 255: # word distance from current position in the code to next
            # entry or module specified in the packed jump table follows
            opcode2, = struct.unpack_from('>H', data, offset=idx)
            idx += 2
            tok = ('distance', opcode2)

        tokens.append(tok)
        if tok[0] == 'end':
            break

    # Mutate the tokens list to merge the 'prev=' tokens
    # e.g. ('distance', n), ('prev=ref_list_head', None) -> ('distance=ref_list_head', n).
    # Walking backwards keeps the indices valid while items are deleted.
    for i in reversed(range(len(tokens) - 1)):
        if tokens[i+1][0].startswith('prev='):
            assert tokens[i][0] == 'distance'
            tokens[i] = (tokens[i][0] + tokens[i+1][0][4:],) + tokens[i][1:]
            del tokens[i+1]

    # From here on, a 'distance' token can be treated as 'distance=module_end'

    if args.pt:
        daccum = 0   # running total of all distances printed so far
        for i, (a, b) in enumerate(tokens):
            if a.startswith('distance'):
                daccum += b
                print('%02d'%i, a, hex(b), '='+hex(daccum))
            elif b is None:
                print('%02d'%i, a)
            else:
                print('%02d'%i, a, hex(b))

    if is_all:
        # Patch table: groups that each begin with a 3-byte condition mask,
        # followed by (jump-table delta, trap word) records.
        patches = defaultdict(list) # mapping from jt number to (trap, cond_names)
        curjt = 0
        end_of_table = False
        while not end_of_table:
            conds = int.from_bytes(data[idx:idx+3], byteorder='big')
            idx += 3
            #print('conds', hex(conds))

            cond_names = []
            for i, n in enumerate(COND_NAMES):
                if conds & (1 << i):
                    cond_names.append(n)
            cond_names = ','.join(cond_names)

            while 1:
                delta = data[idx]
                idx += 1
                #print(' delta', hex(delta))
                if delta == 254:
                    break # get new condition set
                elif delta == 255:
                    delta, = struct.unpack_from('>H', data, idx)
                    idx += 2
                    #print(' delta2', hex(delta))
                    # NOTE(review): the zero check is taken to apply to the
                    # word form only, so that a zero byte delta can add a
                    # second patch to the same jump-table entry -- confirm
                    # against the format.
                    if delta == 0:
                        end_of_table = True
                        break

                curjt += delta

                trap, = struct.unpack_from('>H', data, idx)
                idx += 2
                #print(' trap', hex(trap))
                patches[curjt].append((trap, cond_names))

    # Turn the tokens into Mod objects.  A fresh Mod is always kept open at
    # the end of the list and is closed by a plain 'distance' token.
    jt_offset = 0    # next jump-table entry number to hand out
    cur_offset = 0   # current position in code

    modules = []

    modules.append(Mod())
    modules[-1].offset = 0
    modules[-1].__hack_refhead = -1
    modules[-1].rsrc_id = num

    for tok, arg in tokens:
        if tok == 'skipjt':
            jt_offset += arg

        if tok.startswith('distance'):
            # The first distance token of a module gives the module its
            # jump-table entry.
            # NOTE(review): the increment is taken to belong to this 'if';
            # confirm.
            if modules[-1].jt_entry == -1:
                modules[-1].jt_entry = jt_offset
                jt_offset += 1

            cur_offset += arg

            if tok == 'distance': # to end of module
                modules[-1].stop = cur_offset
                modules.append(Mod())
                modules[-1].offset = cur_offset
                modules[-1].__hack_refhead = -1
                modules[-1].rsrc_id = num

            if tok == 'distance=entry_not_module':
                modules[-1].entry_points.append(Ent())
                modules[-1].entry_points[-1].offset = cur_offset
                modules[-1].entry_points[-1].jt_entry = jt_offset
                jt_offset += 1

            if tok == 'distance=ref_list_head':
                modules[-1].__hack_refhead = cur_offset

    # Drop the Mod that was opened after the last module ended
    modules.pop()
    if modules:
        assert modules[-1].stop == code_size

    # Hand each ROM reference to the module whose range contains it
    for m in modules:
        m.rom_references = [r for r in rom_references if m.offset <= r.offset < m.stop]

    # Walk each module's chain of code references.  A node is two big-endian
    # words: word1 holds the force-resident flag (top bit) and the distance
    # to the next node in 2-byte units (low 15 bits, 0 = end of chain);
    # word2 holds the opcode (top 4 bits) and the jump-table entry (low 12).
    for m in modules:
        if m.__hack_refhead == -1:
            continue

        while 1:
            word1, word2 = struct.unpack_from('>HH', code, offset=m.__hack_refhead)

            m.references.append(Ref())
            m.references[-1].offset = m.__hack_refhead
            m.references[-1].jt_entry = word2 & 0xFFF
            m.references[-1].opcode = word2 >> 12
            m.references[-1].force_resident = bool(word1 & 0x8000)

            dist_to_next = word1 & 0x7FFF
            dist_to_next *= 2

            if dist_to_next == 0:
                break
            m.__hack_refhead += dist_to_next

    if args.pm:
        for m in modules:
            print(m)

    edited_code = bytearray(code)

    # Now edit the code to look more sexier...
    # Every reference node is overwritten with an instruction word followed
    # by the jump-table entry number.  Opcodes 11..15 have no instruction in
    # the list and get the four bytes 4E 71 4E 71 instead (the NOPs that the
    # -oe help text mentions).
    for m in modules:
        for r in m.references:
            try:
                opcode = [0x206D,0x226D,0x246D,0x266D,0x286D,0x2A6D,0x2C6D,0x2E6D,0x2F2D,0x4EAD,0x4EED][r.opcode]
                nu = struct.pack('>HH', opcode, r.jt_entry)
            except IndexError:
                nu = b'NqNq'

            edited_code[r.offset:r.offset+4] = nu

    if args.oe:
        with open(args.oe + str(num), 'wb') as out:
            out.write(edited_code)

    all_modules.extend(modules)
# Every slot of the combined ROM table must have been filled by some resource
assert all(el is not None for el in large_rom_table)

# -pp: one MakePatch line per (trap, conditions) pair, in jump-table order
if args.pp:
    for jt in sorted(patches):
        for trap, cond_names in patches[jt]:
            print(f' MakePatch {name(jt)}, _{trap:04X}, ({cond_names})')
# -pj / -pjh: dump the code of every module in jump-table order, as text or
# as hex, with an annotation line under each module start, entry point, code
# reference and ROM reference.
if args.pj or args.pjh:
    # Resource IDs in lpch_list order; code_list uses the same order
    nums = [num for num, data in lpch_list]

    # Output columns per code byte: hex mode prints 4 digits plus a space for
    # every 2 bytes
    CHARWID = 2.5 if args.pjh else 1 # for hex

    def render_line(ofs, line):
        """Prefix line with the offset as 5 hex digits."""
        return '%05x: %s' % (ofs, line)

    def render_code(start, stop):
        """Yield the rendered rows covering code[start:stop].

        Reads the module-level ``code``, which the loop below rebinds for each
        module.  Rows are cut at multiples of args.width, and a row that
        starts mid-way is padded on the left so its bytes stay in their
        columns.  Raises ValueError when code ends before stop.
        """
        ofs = start
        while ofs < stop:
            # End of this row: the next multiple of the width, but not past stop
            ofs2 = ofs + args.width; ofs2 -= ofs2 % args.width; ofs2 = min(ofs2, stop)
            line = code[ofs:ofs2]
            if not line:
                print('expected', stop, 'got', len(code))
                raise ValueError()
            if args.pjh:
                line = ' '.join(line[o:o+2].hex() for o in range(0, len(line), 2))
            else:
                # Bytes outside 33..126 are shown as '.' (46)
                line = bytes(x if (32 < x and x != 127 and x != 0xF0 and x < 127) else 46 for x in line).decode('mac_roman')
            line = ' ' * int(CHARWID * (ofs % args.width)) + line
            yield render_line(ofs, line)
            ofs = ofs2

    def render_offset(ofs, line):
        """Render an annotation indented to the column of the byte at ofs."""
        return '%05x: %s%s' % (ofs, ' ' * int(CHARWID * (ofs % args.width)), line)

    def render_sep(ofs):
        """Render a separator row of '=' characters."""
        return '%05x: %s' % (ofs, '=' * args.width)

    last_rsrc = -1
    # Per resource: how far into its code the dump has got
    rsrc_print_progress = [0] * len(nums)

    all_modules.sort(key=lambda mod: mod.jt_entry)
    for mod in all_modules:
        rsrc_idx = nums.index(mod.rsrc_id)
        # The module itself and everything inside it, in code order
        everything = sorted([mod] + mod.entry_points + mod.references + mod.rom_references, key=lambda x: x.offset)
        code = code_list[rsrc_idx]
        last_printed = 0   # not read anywhere in this section
        leftside = str(mod.rsrc_id).zfill(2) + ':'

        def myprint(*args, **kwargs):
            # print() with the resource ID prefixed to the first argument.
            # Note that the parameter shadows the module-level args in here.
            if args: args = (leftside + str(args[0]), *args[1:])
            return print(*args, **kwargs)

        def print_up_to(ofs):
            # Print the code between the resource's progress mark and ofs,
            # then move the mark to ofs.
            # NOTE(review): the assignment is taken to follow the loop rather
            # than sit inside it -- confirm.
            for jank in render_code(rsrc_print_progress[rsrc_idx], ofs):
                myprint(jank)
            rsrc_print_progress[rsrc_idx] = ofs

        # Header whenever the output moves to a different resource
        if last_rsrc != mod.rsrc_id:
            myprint(render_sep(mod.offset))
            matches_roms = []
            for i, r in enumerate(args.roms):
                if (mod.rsrc_id >> i) & 1: matches_roms.append(r)
            myprint(render_line(mod.offset, ','.join(matches_roms)))
            # print()
            last_rsrc = mod.rsrc_id

        # The trailing None flushes the rest of the module up to mod.stop
        for mod_ent in everything + [None]:
            print_up_to(mod_ent.offset if mod_ent else mod.stop)
            if mod_ent:
                myprint(render_offset(mod_ent.offset, '(' + str(mod_ent) + ')'))
                try:
                    # patches is a defaultdict, so entries without patches
                    # simply give an empty list
                    for trap, cond_names in patches[mod_ent.jt_entry]:
                        myprint(render_offset(mod_ent.offset, f'${trap:X},({cond_names})'))
                except AttributeError:
                    # RomRef objects have no jt_entry attribute
                    pass
                # Ref and RomRef: step the progress mark over the 4 bytes at
                # the reference so they are not dumped as code
                if not (isinstance(mod_ent, Mod) or isinstance(mod_ent, Ent)):
                    rsrc_print_progress[rsrc_idx] += 4 # close enough
# print(large_jump_table)
# print(large_rom_table)