hystricomorph/string_compiler.py
Kelvin Sherlock 07cad750c9 improved lowercase matching algorithm.
Only ora #$20 if the target character is lowercase. Order the comparisons so the accumulator can be successively modified.
Worst case, 2 loads are needed per word:
"##" -> load, no ora
"A#" -> ora #$0020
"#A" -> reload, ora #$2000
"AA" -> ora #$2020
2019-08-09 18:52:20 -04:00

241 lines
4.7 KiB
Python

import getopt
import sys
import re
from functools import reduce
from asm import Assembler
# Case-insensitive mode switch. main() sets it to True when the -i option is
# given; read_data() then lower-cases every key, and generate_c() /
# generate_asm() consult it when emitting the 0x20 case-folding OR.
flag_ci = False
def printf(fmt, *args):
    """Print ``fmt`` %-formatted with ``args``, followed by a newline."""
    text = fmt % args
    print(text)
def c_encode(c):
    """Return ``c`` with a leading backslash if it is a backslash or a
    single quote; otherwise return it unchanged."""
    needs_escape = c in '\\\''
    return '\\' + c if needs_escape else c
def str_encode(s):
    """Escape every character of ``s`` with c_encode() and return the
    escaped pieces joined in reverse order (last character first)."""
    pieces = list(map(c_encode, s))
    pieces.reverse()
    return "".join(pieces)
def str_xx(s):
    """Return the two-digit lowercase hex code of every character of ``s``,
    concatenated in reverse order (last character first)."""
    hex_pairs = ["%02x" % ord(ch) for ch in s]
    hex_pairs.reverse()
    return "".join(hex_pairs)
def generate_c(d, level, preserve):
    """Print C source that matches the input against the trie node ``d``.

    ``d`` maps one- or two-character chunks to child nodes; the key ""
    holds the value to return when the string ends at this node.
    ``level`` is the depth in the trie: it controls the output prefix and
    the ``cp`` offset (``level * 2``) that is read.  ``preserve`` is
    accepted but not used by the current implementation.
    """
    pad = " " * level
    offset = level * 2

    wide_keys = [key for key in d if len(key) == 2]
    narrow_keys = [key for key in d if len(key) == 1]

    # Number of real branches: everything except the "" terminal entry.
    branches = len(d) - (1 if "" in d else 0)

    if branches > 0:
        # Load a full word when any two-character key exists, else one byte.
        if wide_keys:
            printf("%s c = *(unsigned *)(cp+%d);", pad, offset)
        else:
            printf("%s c = *(unsigned char *)(cp+%d);", pad, offset)

        if flag_ci:
            if wide_keys:
                printf("%s c |= 0x2020;", pad)
            else:
                printf("%s c |= 0x20;", pad)

    def emit_branch(key):
        # One `if` block per key, recursing into the child node.
        printf("%s if (c=='%s'){", pad, str_encode(key))
        generate_c(d[key], level + 1, branches > 1)
        printf("%s }", pad)

    for key in wide_keys:
        emit_branch(key)

    # Single-character keys compare against the low byte only.
    if narrow_keys:
        printf("%s c &= 0xff;", pad)
    for key in narrow_keys:
        emit_branch(key)

    printf("%s return %d", pad, d.get("", 0))
def str_to_int(cc):
    """Pack the character codes of ``cc`` into one integer, 8 bits per
    character, with the first character in the lowest byte."""
    value = 0
    for ch in reversed(cc):
        value = (value << 8) + ord(ch)
    return value
def str_to_print(cc):
    """Return ``cc`` with every non-printable character replaced by "."."""
    return "".join(
        ch if ch.isprintable() else "."
        for ch in cc
    )
def or_mask(cc):
    """Build the case-folding OR mask for ``cc``.

    Each character contributes one byte (first character in the lowest
    byte): 0x20 if the character is lowercase, 0x00 otherwise.
    """
    mask = 0
    for ch in reversed(cc):
        mask = (mask << 8) + (0x20 if ch.islower() else 0)
    return mask
def load_char(asm, dirty, level, short_m, old, new):
    """Emit whatever is needed for the accumulator to carry OR mask ``new``.

    ``old`` is the mask already applied to the loaded value.  A reload
    (``lda``) is emitted when ``dirty`` is true or when ``old`` has bits
    that ``new`` lacks, because OR-ing cannot clear bits.  An ``ora`` is
    then emitted only if the applied mask still differs from ``new``.
    ``level`` selects the addressing form (``(cp)`` at level 0, ``(cp),y``
    otherwise) and ``short_m`` selects a 2- or 4-digit immediate.

    The second argument to ``asm.emit`` is presumably the instruction
    size in bytes -- TODO confirm against the Assembler class.

    Returns ``new``, the mask now in effect.
    """
    must_reload = dirty or bool(old & ~new)
    if must_reload:
        asm.emit("lda (cp)" if level == 0 else "lda (cp),y", 2)
        old = 0  # a fresh load has no mask applied yet

    if old != new:
        if short_m:
            asm.emit("ora #${:02x}".format(new), 2)
        else:
            asm.emit("ora #${:04x}".format(new), 3)
    return new
def generate_asm(asm, d, level):
    """Emit assembly that matches the input against the trie node ``d``.

    ``d`` maps one- or two-character chunks to child nodes; the key ""
    holds the value loaded into X (``ldx``) when the string ends here.
    ``level`` is the depth in the trie; for ``level > 0`` the Y register
    is loaded with ``level * 2`` before the input is read.

    Two-character keys are compared first with 16-bit immediates; the
    single-character keys follow after an ``sep #$20`` with 8-bit
    immediates.  Every node ends with ``asm.rts()``.

    In case-insensitive mode (``flag_ci``) the keys are ordered by their
    OR mask so that the accumulator can usually be modified in place by
    successive ``ora`` instructions instead of being reloaded.  In
    case-sensitive mode no ``ora`` is emitted at all, so a key only
    matches input with exactly the same case.

    The second argument to ``asm.emit`` is presumably the instruction
    size in bytes -- TODO confirm against the Assembler class.
    """
    double = [key for key in d if len(key) == 2]
    single = [key for key in d if len(key) == 1]
    short_m = bool(single) and not double

    def fold_mask(key):
        # Case folding must only happen in -i mode; otherwise OR-ing 0x20
        # into the input would make a lowercase key match uppercase input.
        return or_mask(key) if flag_ci else 0

    mask = 0
    if flag_ci:
        # Ascending mask order lets later keys add bits to the accumulator.
        single.sort(key=or_mask)
        double.sort(key=or_mask)
        if single:
            mask = or_mask(single[0])
        if double:
            # Doubles are compared first, so their first mask wins.
            mask = or_mask(double[0])

    # Number of real branches: everything except the "" terminal entry.
    count = len(d)
    if "" in d:
        count -= 1

    if count > 0:
        if short_m:
            asm.emit("longa off", 0)
            asm.emit("sep #$20", 2)
        if level > 0:
            asm.emit("ldy #{}".format(level * 2), 3)
        mask = load_char(asm, True, level, short_m, 0, mask)

    for key in double:
        skip = asm.reserve_label()
        mask = load_char(asm, False, level, short_m, mask, fold_mask(key))
        asm.emit("cmp #${:04x}\t; '{}'".format(str_to_int(key), str_to_print(key)), 3)
        asm.bne(skip)
        generate_asm(asm, d[key], level + 1)
        asm.emit_label(skip)

    if single and double:
        # Switch to 8-bit compares; only the low byte of the mask remains
        # relevant for the value already in the accumulator.
        asm.emit("longa off", 0)
        asm.emit("sep #$20", 2)
        short_m = True
        mask = mask & 0xff

    for key in single:
        skip = asm.reserve_label()
        mask = load_char(asm, False, level, short_m, mask, fold_mask(key))
        asm.emit("cmp #${:02x}\t; '{}'".format(str_to_int(key), str_to_print(key)), 2)
        asm.bne(skip)
        generate_asm(asm, d[key], level + 1)
        asm.emit_label(skip)

    if short_m:
        asm.emit("longa on", 0)

    if "" in d:
        asm.emit("ldx #{}".format(d[""]), 3)

    asm.rts()
def process(data, name):
    """Build a trie from ``data`` and write the generated assembly to stdout.

    ``data`` maps each string to its value.  Every string is split into
    two-character chunks (the last chunk has one character when the
    length is odd); each chunk is one trie level, and the value is stored
    under the "" key of the final node.  ``name`` is passed to the
    Assembler constructor.
    """
    root = {}
    for text, value in data.items():
        node = root
        for start in range(0, len(text), 2):
            # Descend into the child for this chunk, creating it if needed.
            node = node.setdefault(text[start:start + 2], {})
        node[""] = value

    out = Assembler(name)
    generate_asm(out, root, 0)
    out.finish(sys.stdout)
def usage(ex=1):
    """Print the command-line synopsis and exit with status ``ex``."""
    synopsis = "Usage: string_compiler [-i] name [file]"
    print(synopsis)
    raise SystemExit(ex)
def read_data(f, name):
    """Parse a string table from the iterable of lines ``f``.

    Blank lines and lines whose first non-blank character is "#" are
    skipped.  Every other line must be a double-quoted string followed by
    whitespace and a decimal number.  When ``flag_ci`` is set the string
    is lower-cased before it is stored.

    ``name`` is only used to label error messages.

    Returns a dict mapping each string to its integer value.
    Raises Exception on a malformed line or a duplicate string.
    """
    table = {}
    for lineno, raw in enumerate(f, 1):
        text = raw.strip()
        if not text or text.startswith("#"):
            continue

        parsed = re.match(r'^"([^"]*)"\s+(\d+)$', text)
        if parsed is None:
            raise Exception("{}:{}: Bad data: {}".format(name, lineno, text))

        key = parsed.group(1)
        if flag_ci:
            key = key.lower()
        value = int(parsed.group(2))

        if key in table:
            raise Exception("{}:{}: Duplicate string: {}".format(name, lineno, key))
        table[key] = value
    return table
def read_stdin():
    """Parse the string table from standard input via read_data()."""
    source = sys.stdin
    return read_data(source, "<stdin>")
def read_file(path):
    """Open ``path`` and parse its string table via read_data(); the path
    doubles as the name shown in error messages."""
    with open(path) as stream:
        table = read_data(stream, path)
    return table
def main():
    """Command-line entry point: ``string_compiler [-i] name [file]``.

    ``-i`` turns on case-insensitive mode (sets the module-level
    ``flag_ci``).  ``name`` is passed on to process().  The table is read
    from ``file``, or from standard input when ``file`` is omitted or is
    "-".  Any invalid option or wrong argument count prints the usage
    message and exits through usage().  On success the process exits with
    status 0.
    """
    global flag_ci

    try:
        opts, args = getopt.getopt(sys.argv[1:], "i")
    except getopt.GetoptError as err:
        # Unknown option: report it and show the synopsis instead of
        # dying with a traceback.  usage() does not return.
        print("string_compiler: {}".format(err), file=sys.stderr)
        usage()

    for opt, _ in opts:
        if opt == "-i":
            flag_ci = True
        else:
            usage()

    if not 1 <= len(args) <= 2:
        usage()

    name = args[0]
    if len(args) == 1 or args[1] == "-":
        data = read_stdin()
    else:
        data = read_file(args[1])

    process(data, name)
    sys.exit(0)
# Run the compiler only when executed as a script, so that importing this
# module (for testing or reuse) does not parse argv and exit.
if __name__ == "__main__":
    main()