diff --git a/string_compiler.py b/string_compiler.py index ebc548f..c0b39eb 100644 --- a/string_compiler.py +++ b/string_compiler.py @@ -13,10 +13,10 @@ def str_to_int(cc): return reduce(fn, reversed(cc), 0) def str_to_print(cc): - return "".join([x if x.isprintable() else "." for x in cc]) + return "".join([x if (x.isascii() and x.isprintable()) else "." for x in cc]) def or_mask(cc): - fn = lambda x, y: (x << 8) + (0x20 * y.islower()) + fn = lambda x, y: (x << 8) + (0x20 * (y.isascii() and y.islower())) return reduce(fn, reversed(cc), 0) def mask_char(asm, short_m, old, new): @@ -78,8 +78,9 @@ def generate_asm(asm, d, level): l = asm.reserve_label() if flag_i: mask = mask_char(asm, short_m, mask, or_mask(k)) v = str_to_int(k) - if v != 0: - asm.emit("cmp #${:04x}\t; '{}'".format(v, str_to_print(k)), 3) + # only valid if preceeded by a lda/ora + # if v != 0: + asm.emit("cmp #${:04x}\t; '{}'".format(v, encode_string(k)), 3) asm.bne(l) generate_asm(asm, dd, level+1) asm.emit_label(l) @@ -95,8 +96,8 @@ def generate_asm(asm, d, level): l = asm.reserve_label() if flag_i: mask = mask_char(asm, short_m, mask, or_mask(k)) v = str_to_int(k) - if v != 0: - asm.emit("cmp #${:02x}\t; '{}'".format(v, str_to_print(k)), 2) + # if v != 0: + asm.emit("cmp #${:02x}\t; '{}'".format(v, encode_string(k)), 2) asm.bne(l) generate_asm(asm, dd, level+1) asm.emit_label(l) @@ -125,13 +126,23 @@ def process(data, name): current[x] = tmp current = tmp - current[""] = { "__value__": data[k], "__key__": k } + current[""] = { "__value__": data[k], "__key__": encode_string(k) } # print(tree); asm = Assembler(name) generate_asm(asm, tree, 0) asm.finish(sys.stdout) +def decode_string(s): + global decode_map + + fn = lambda x: decode_map.get(x[1].lower(), '') + return re.sub(r"\\([xX][A-Fa-f0-9]{2}|.)", fn, s) + +def encode_string(s): + global encode_map + return "".join([encode_map.get(x, x) for x in s]) + def usage(ex=1): print("Usage: string_compiler [-i] name [file]") @@ -153,14 +164,14 @@ def read_data(f, name): if not m: err = "{}:{}: Bad data: {}".format(name,ln,line) raise Exception(err) - k = m[1] + k = orig_k = m[1] if flag_i: k = k.lower() - + k = decode_string(k) v = int(m[2]) if k in data: - err = "{}:{}: Duplicate string: {}".format(name,ln,k) + err = "{}:{}: Duplicate string: {}".format(name,ln,orig_k) raise Exception(err) data[k] = v @@ -178,6 +189,39 @@ def read_file(path): def main(): global flag_i + global decode_map + global encode_map + + decode_map = {} + for i in range(0, 256): + decode_map["x{:02x}".format(i)] = chr(i) + decode_map['\\'] = '\\' + decode_map["'"] = "'" + decode_map['"'] = '"' + decode_map['?'] = '?' + + decode_map['a'] = chr(7) + decode_map['b'] = chr(8) + decode_map['f'] = chr(12) + decode_map['n'] = chr(10) + decode_map['r'] = chr(13) + decode_map['t'] = chr(9) + decode_map['v'] = chr(11) + + encode_map = {} + for i in range(0, 20): encode_map[chr(i)] = "\\x{:02x}".format(i) + for i in range(127, 256): encode_map[chr(i)] = "\\x{:02x}".format(i) + + encode_map['\\'] = '\\\\' + encode_map[chr(7)] = '\\a' + encode_map[chr(8)] = '\\b' + encode_map[chr(12)] = '\\f' + encode_map[chr(10)] = '\\n' + encode_map[chr(13)] = '\\r' + encode_map[chr(9)] = '\\t' + encode_map[chr(11)] = '\\v' + + argv = sys.argv[1:] opts, args = getopt.getopt(argv, "i")