From a0328b8840b5b7817c27a1d68a9ac459d8f71421 Mon Sep 17 00:00:00 2001 From: Chris Pressey Date: Wed, 10 Apr 2019 08:48:33 +0100 Subject: [PATCH] Store type information in SymbolTable shared across phases. --- bin/sixtypical | 14 +-- src/sixtypical/analyzer.py | 205 +++++++++++++++++++++---------------- src/sixtypical/ast.py | 4 +- src/sixtypical/compiler.py | 109 ++++++++++++-------- src/sixtypical/fallthru.py | 5 +- src/sixtypical/model.py | 50 ++------- src/sixtypical/parser.py | 112 ++++++++++---------- 7 files changed, 263 insertions(+), 236 deletions(-) diff --git a/bin/sixtypical b/bin/sixtypical index 505a301..7a86cce 100755 --- a/bin/sixtypical +++ b/bin/sixtypical @@ -19,22 +19,22 @@ from pprint import pprint import sys import traceback -from sixtypical.parser import Parser, ParsingContext, merge_programs +from sixtypical.parser import Parser, SymbolTable, merge_programs from sixtypical.analyzer import Analyzer from sixtypical.outputter import outputter_class_for from sixtypical.compiler import Compiler def process_input_files(filenames, options): - context = ParsingContext() + symtab = SymbolTable() programs = [] for filename in options.filenames: text = open(filename).read() - parser = Parser(context, text, filename) + parser = Parser(symtab, text, filename) if options.debug: - print(context) + print(symtab) program = parser.program() programs.append(program) @@ -43,7 +43,7 @@ def process_input_files(filenames, options): program = merge_programs(programs) - analyzer = Analyzer(debug=options.debug) + analyzer = Analyzer(symtab, debug=options.debug) try: analyzer.analyze_program(program) @@ -64,7 +64,7 @@ def process_input_files(filenames, options): sys.stdout.write(json.dumps(data, indent=4, sort_keys=True, separators=(',', ':'))) sys.stdout.write("\n") - fa = FallthruAnalyzer(debug=options.debug) + fa = FallthruAnalyzer(symtab, debug=options.debug) fa.analyze_program(program) compilation_roster = fa.serialize() dump(compilation_roster) @@ -82,7 +82,7 @@ def process_input_files(filenames, options): with open(options.output, 'wb') as fh: outputter = outputter_class_for(options.output_format)(fh, start_addr=start_addr) outputter.write_prelude() - compiler = Compiler(outputter.emitter) + compiler = Compiler(symtab, outputter.emitter) compiler.compile_program(program, compilation_roster=compilation_roster) outputter.write_postlude() if options.debug: diff --git a/src/sixtypical/analyzer.py b/src/sixtypical/analyzer.py index 158232f..68ff759 100644 --- a/src/sixtypical/analyzer.py +++ b/src/sixtypical/analyzer.py @@ -80,16 +80,7 @@ class IncompatibleConstraintsError(ConstraintsError): pass -def routine_has_static(routine, ref): - if not hasattr(routine, 'statics'): - return False - for static in routine.statics: - if static.location == ref: - return True - return False - - -class Context(object): +class AnalysisContext(object): """ A location is touched if it was changed (or even potentially changed) during this routine, or some routine called by this routine. @@ -108,8 +99,8 @@ class Context(object): lists of this routine. A location can also be temporarily marked unwriteable in certain contexts, such as `for` loops. """ - def __init__(self, routines, routine, inputs, outputs, trashes): - self.routines = routines # LocationRef -> Routine (AST node) + def __init__(self, symtab, routine, inputs, outputs, trashes): + self.symtab = symtab self.routine = routine # Routine (AST node) self._touched = set() # {LocationRef} self._range = dict() # LocationRef -> (Int, Int) @@ -119,29 +110,30 @@ class Context(object): self._pointer_assoc = dict() for ref in inputs: - if ref.is_constant(): + if self.is_constant(ref): raise ConstantConstraintError(self.routine, ref.name) - self._range[ref] = ref.max_range() + self._range[ref] = self.max_range(ref) output_names = set() for ref in outputs: - if ref.is_constant(): + if self.is_constant(ref): raise ConstantConstraintError(self.routine, ref.name) output_names.add(ref.name) self._writeable.add(ref) for ref in trashes: - if ref.is_constant(): + if self.is_constant(ref): raise ConstantConstraintError(self.routine, ref.name) if ref.name in output_names: raise InconsistentConstraintsError(self.routine, ref.name) self._writeable.add(ref) def __str__(self): - return "Context(\n _touched={},\n _range={},\n _writeable={}\n)".format( + return "{}(\n _touched={},\n _range={},\n _writeable={}\n)".format( + self.__class__.__name__, LocationRef.format_set(self._touched), LocationRef.format_set(self._range), LocationRef.format_set(self._writeable) ) def to_json_data(self): - type_ = self.routine.location.type + type_ = self.symtab.fetch_global_type(self.routine.name) return { 'routine_inputs': ','.join(sorted(loc.name for loc in type_.inputs)), 'routine_outputs': ','.join(sorted(loc.name for loc in type_.outputs)), @@ -154,7 +146,7 @@ class Context(object): } def clone(self): - c = Context(self.routines, self.routine, [], [], []) + c = AnalysisContext(self.symtab, self.routine, [], [], []) c._touched = set(self._touched) c._range = dict(self._range) c._writeable = set(self._writeable) @@ -169,7 +161,6 @@ class Context(object): We do not replace the gotos_encountered for technical reasons. (In `analyze_if`, we merge those sets afterwards; at the end of `analyze_routine`, they are not distinct in the set of contexts we are updating from, and we want to retain our own.)""" - self.routines = other.routines self.routine = other.routine self._touched = set(other._touched) self._range = dict(other._range) @@ -193,9 +184,9 @@ class Context(object): exception_class = kwargs.get('exception_class', UnmeaningfulReadError) for ref in refs: # statics are always meaningful - if routine_has_static(self.routine, ref): + if self.symtab.fetch_static_ref(self.routine.name, ref.name): continue - if ref.is_constant() or ref in self.routines: + if self.is_constant(ref): pass elif isinstance(ref, LocationRef): if ref not in self._range: @@ -213,7 +204,7 @@ class Context(object): exception_class = kwargs.get('exception_class', ForbiddenWriteError) for ref in refs: # statics are always writeable - if routine_has_static(self.routine, ref): + if self.symtab.fetch_static_ref(self.routine.name, ref.name): continue if ref not in self._writeable: message = ref.name @@ -234,7 +225,7 @@ class Context(object): if outside in self._range: outside_range = self._range[outside] else: - outside_range = outside.max_range() + outside_range = self.max_range(outside) if (inside_range[0] + offset.value) < outside_range[0] or (inside_range[1] + offset.value) > outside_range[1]: raise RangeExceededError(self.routine, @@ -251,7 +242,7 @@ class Context(object): def set_meaningful(self, *refs): for ref in refs: if ref not in self._range: - self._range[ref] = ref.max_range() + self._range[ref] = self.max_range(ref) def set_top_of_range(self, ref, top): self.assert_meaningful(ref) @@ -293,12 +284,12 @@ class Context(object): if src in self._range: src_range = self._range[src] else: - src_range = src.max_range() + src_range = self.max_range(src) self._range[dest] = src_range def invalidate_range(self, ref): self.assert_meaningful(ref) - self._range[ref] = ref.max_range() + self._range[ref] = self.max_range(ref) def set_unmeaningful(self, *refs): for ref in refs: @@ -336,19 +327,6 @@ class Context(object): def has_terminated(self): return self._terminated - def assert_types_for_read_table(self, instr, src, dest, type_, offset): - if (not TableType.is_a_table_type(src.ref.type, type_)) or (not dest.type == type_): - raise TypeMismatchError(instr, '{} and {}'.format(src.ref.name, dest.name)) - self.assert_meaningful(src, src.index) - self.assert_in_range(src.index, src.ref, offset) - - def assert_types_for_update_table(self, instr, dest, type_, offset): - if not TableType.is_a_table_type(dest.ref.type, type_): - raise TypeMismatchError(instr, '{}'.format(dest.ref.name)) - self.assert_meaningful(dest.index) - self.assert_in_range(dest.index, dest.ref, offset) - self.set_written(dest.ref) - def extract(self, location): """Sets the given location as writeable in the context, and returns a 'baton' representing the previous state of context for that location. This 'baton' can be used to later restore @@ -390,18 +368,53 @@ class Context(object): def set_assoc(self, pointer, table): self._pointer_assoc[pointer] = table + def is_constant(self, ref): + """read-only means that the program cannot change the value + of a location. constant means that the value of the location + will not change during the lifetime of the program.""" + if isinstance(ref, ConstantRef): + return True + if isinstance(ref, (IndirectRef, IndexedRef)): + return False + if isinstance(ref, LocationRef): + type_ = self.symtab.fetch_global_type(ref.name) + return isinstance(type_, RoutineType) + raise NotImplementedError + + def max_range(self, ref): + if isinstance(ref, ConstantRef): + return (ref.value, ref.value) + elif self.symtab.has_static(self.routine.name, ref.name): + return self.symtab.fetch_static_type(self.routine.name, ref.name).max_range + else: + return self.symtab.fetch_global_type(ref.name).max_range + class Analyzer(object): - def __init__(self, debug=False): + def __init__(self, symtab, debug=False): + self.symtab = symtab self.current_routine = None - self.routines = {} self.debug = debug self.exit_contexts_map = {} + # - - - - helper methods - - - - + + def get_type_for_name(self, name): + if self.current_routine and self.symtab.has_static(self.current_routine.name, name): + return self.symtab.fetch_static_type(self.current_routine.name, name) + return self.symtab.fetch_global_type(name) + + def get_type(self, ref): + if isinstance(ref, ConstantRef): + return ref.type + if not isinstance(ref, LocationRef): + raise NotImplementedError + return self.get_type_for_name(ref.name) + def assert_type(self, type_, *locations): for location in locations: - if location.type != type_: + if self.get_type(location) != type_: raise TypeMismatchError(self.current_routine, location.name) def assert_affected_within(self, name, affecting_type, limiting_type): @@ -419,9 +432,23 @@ class Analyzer(object): ) raise IncompatibleConstraintsError(self.current_routine, message) + def assert_types_for_read_table(self, context, instr, src, dest, type_, offset): + if (not TableType.is_a_table_type(self.get_type(src.ref), type_)) or (not self.get_type(dest) == type_): + raise TypeMismatchError(instr, '{} and {}'.format(src.ref.name, dest.name)) + context.assert_meaningful(src, src.index) + context.assert_in_range(src.index, src.ref, offset) + + def assert_types_for_update_table(self, context, instr, dest, type_, offset): + if not TableType.is_a_table_type(self.get_type(dest.ref), type_): + raise TypeMismatchError(instr, '{}'.format(dest.ref.name)) + context.assert_meaningful(dest.index) + context.assert_in_range(dest.index, dest.ref, offset) + context.set_written(dest.ref) + + # - - - - visitor methods - - - - + def analyze_program(self, program): assert isinstance(program, Program) - self.routines = {r.location: r for r in program.routines} for routine in program.routines: context = self.analyze_routine(routine) routine.encountered_gotos = list(context.encountered_gotos()) if context else [] @@ -433,8 +460,8 @@ class Analyzer(object): return None self.current_routine = routine - type_ = routine.location.type - context = Context(self.routines, routine, type_.inputs, type_.outputs, type_.trashes) + type_ = self.get_type_for_name(routine.name) + context = AnalysisContext(self.symtab, routine, type_.inputs, type_.outputs, type_.trashes) self.exit_contexts = [] self.analyze_block(routine.block, context) @@ -478,7 +505,10 @@ class Analyzer(object): # if something was touched, then it should have been declared to be writable. for ref in context.each_touched(): - if ref not in type_.outputs and ref not in type_.trashes and not routine_has_static(routine, ref): + # FIXME once we have namedtuples, go back to comparing the ref directly! + outputs_names = [r.name for r in type_.outputs] + trashes_names = [r.name for r in type_.trashes] + if ref.name not in outputs_names and ref.name not in trashes_names and not self.symtab.has_static(routine.name, ref.name): raise ForbiddenWriteError(routine, ref.name) self.exit_contexts = None @@ -525,10 +555,10 @@ class Analyzer(object): if opcode == 'ld': if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) elif isinstance(src, IndirectRef): # copying this analysis from the matching branch in `copy`, below - if isinstance(src.ref.type, PointerType) and dest.type == TYPE_BYTE: + if isinstance(self.get_type(src.ref), PointerType) and self.get_type(dest) == TYPE_BYTE: pass else: raise TypeMismatchError(instr, (src, dest)) @@ -539,7 +569,7 @@ class Analyzer(object): context.assert_meaningful(origin) context.assert_meaningful(src.ref, REG_Y) - elif src.type != dest.type: + elif self.get_type(src) != self.get_type(dest): raise TypeMismatchError(instr, '{} and {}'.format(src.name, dest.name)) else: context.assert_meaningful(src) @@ -547,12 +577,12 @@ class Analyzer(object): context.set_written(dest, FLAG_Z, FLAG_N) elif opcode == 'st': if isinstance(dest, IndexedRef): - if src.type != TYPE_BYTE: + if self.get_type(src) != TYPE_BYTE: raise TypeMismatchError(instr, (src, dest)) - context.assert_types_for_update_table(instr, dest, TYPE_BYTE, dest.offset) + self.assert_types_for_update_table(context, instr, dest, TYPE_BYTE, dest.offset) elif isinstance(dest, IndirectRef): # copying this analysis from the matching branch in `copy`, below - if isinstance(dest.ref.type, PointerType) and src.type == TYPE_BYTE: + if isinstance(self.get_type(dest.ref), PointerType) and self.get_type(src) == TYPE_BYTE: pass else: raise TypeMismatchError(instr, (src, dest)) @@ -565,7 +595,7 @@ class Analyzer(object): context.set_touched(target) context.set_written(target) - elif src.type != dest.type: + elif self.get_type(src) != self.get_type(dest): raise TypeMismatchError(instr, '{} and {}'.format(src, dest)) else: context.set_written(dest) @@ -574,18 +604,19 @@ class Analyzer(object): elif opcode == 'add': context.assert_meaningful(src, dest, FLAG_C) if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) - elif src.type == TYPE_BYTE: + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) + elif self.get_type(src) == TYPE_BYTE: self.assert_type(TYPE_BYTE, src, dest) if dest != REG_A: context.set_touched(REG_A) context.set_unmeaningful(REG_A) else: self.assert_type(TYPE_WORD, src) - if dest.type == TYPE_WORD: + dest_type = self.get_type(dest) + if dest_type == TYPE_WORD: context.set_touched(REG_A) context.set_unmeaningful(REG_A) - elif isinstance(dest.type, PointerType): + elif isinstance(dest_type, PointerType): context.set_touched(REG_A) context.set_unmeaningful(REG_A) else: @@ -595,8 +626,8 @@ class Analyzer(object): elif opcode == 'sub': context.assert_meaningful(src, dest, FLAG_C) if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) - elif src.type == TYPE_BYTE: + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) + elif self.get_type(src) == TYPE_BYTE: self.assert_type(TYPE_BYTE, src, dest) if dest != REG_A: context.set_touched(REG_A) @@ -610,8 +641,8 @@ class Analyzer(object): elif opcode == 'cmp': context.assert_meaningful(src, dest) if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) - elif src.type == TYPE_BYTE: + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) + elif self.get_type(src) == TYPE_BYTE: self.assert_type(TYPE_BYTE, src, dest) else: self.assert_type(TYPE_WORD, src, dest) @@ -620,7 +651,7 @@ class Analyzer(object): context.set_written(FLAG_Z, FLAG_N, FLAG_C) elif opcode == 'and': if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) else: self.assert_type(TYPE_BYTE, src, dest) context.assert_meaningful(src, dest) @@ -632,7 +663,7 @@ class Analyzer(object): context.set_top_of_range(dest, context.get_top_of_range(src)) elif opcode in ('or', 'xor'): if isinstance(src, IndexedRef): - context.assert_types_for_read_table(instr, src, dest, TYPE_BYTE, src.offset) + self.assert_types_for_read_table(context, instr, src, dest, TYPE_BYTE, src.offset) else: self.assert_type(TYPE_BYTE, src, dest) context.assert_meaningful(src, dest) @@ -641,7 +672,7 @@ class Analyzer(object): elif opcode in ('inc', 'dec'): context.assert_meaningful(dest) if isinstance(dest, IndexedRef): - context.assert_types_for_update_table(instr, dest, TYPE_BYTE, dest.offset) + self.assert_types_for_update_table(context, instr, dest, TYPE_BYTE, dest.offset) context.set_written(dest.ref, FLAG_Z, FLAG_N) #context.invalidate_range(dest) else: @@ -664,7 +695,7 @@ class Analyzer(object): elif opcode in ('shl', 'shr'): context.assert_meaningful(dest, FLAG_C) if isinstance(dest, IndexedRef): - context.assert_types_for_update_table(instr, dest, TYPE_BYTE, dest.offset) + self.assert_types_for_update_table(context, instr, dest, TYPE_BYTE, dest.offset) context.set_written(dest.ref, FLAG_Z, FLAG_N, FLAG_C) #context.invalidate_range(dest) else: @@ -678,51 +709,51 @@ class Analyzer(object): # 1. check that their types are compatible if isinstance(src, (LocationRef, ConstantRef)) and isinstance(dest, IndirectRef): - if src.type == TYPE_BYTE and isinstance(dest.ref.type, PointerType): + if self.get_type(src) == TYPE_BYTE and isinstance(self.get_type(dest.ref), PointerType): pass else: raise TypeMismatchError(instr, (src, dest)) elif isinstance(src, IndirectRef) and isinstance(dest, LocationRef): - if isinstance(src.ref.type, PointerType) and dest.type == TYPE_BYTE: + if isinstance(self.get_type(src.ref), PointerType) and self.get_type(dest) == TYPE_BYTE: pass else: raise TypeMismatchError(instr, (src, dest)) elif isinstance(src, IndirectRef) and isinstance(dest, IndirectRef): - if isinstance(src.ref.type, PointerType) and isinstance(dest.ref.type, PointerType): + if isinstance(self.get_type(src.ref), PointerType) and isinstance(self.get_type(dest.ref), PointerType): pass else: raise TypeMismatchError(instr, (src, dest)) elif isinstance(src, (LocationRef, ConstantRef)) and isinstance(dest, IndexedRef): - if src.type == TYPE_WORD and TableType.is_a_table_type(dest.ref.type, TYPE_WORD): + if self.get_type(src) == TYPE_WORD and TableType.is_a_table_type(self.get_type(dest.ref), TYPE_WORD): pass - elif (isinstance(src.type, VectorType) and isinstance(dest.ref.type, TableType) and - RoutineType.executable_types_compatible(src.type.of_type, dest.ref.type.of_type)): + elif (isinstance(self.get_type(src), VectorType) and isinstance(self.get_type(dest.ref), TableType) and + RoutineType.executable_types_compatible(self.get_type(src).of_type, self.get_type(dest.ref).of_type)): pass - elif (isinstance(src.type, RoutineType) and isinstance(dest.ref.type, TableType) and - RoutineType.executable_types_compatible(src.type, dest.ref.type.of_type)): + elif (isinstance(self.get_type(src), RoutineType) and isinstance(self.get_type(dest.ref), TableType) and + RoutineType.executable_types_compatible(self.get_type(src), self.get_type(dest.ref).of_type)): pass else: raise TypeMismatchError(instr, (src, dest)) context.assert_in_range(dest.index, dest.ref, dest.offset) elif isinstance(src, IndexedRef) and isinstance(dest, LocationRef): - if TableType.is_a_table_type(src.ref.type, TYPE_WORD) and dest.type == TYPE_WORD: + if TableType.is_a_table_type(self.get_type(src.ref), TYPE_WORD) and self.get_type(dest) == TYPE_WORD: pass - elif (isinstance(src.ref.type, TableType) and isinstance(dest.type, VectorType) and - RoutineType.executable_types_compatible(src.ref.type.of_type, dest.type.of_type)): + elif (isinstance(self.get_type(src.ref), TableType) and isinstance(self.get_type(dest), VectorType) and + RoutineType.executable_types_compatible(self.get_type(src.ref).of_type, self.get_type(dest).of_type)): pass else: raise TypeMismatchError(instr, (src, dest)) context.assert_in_range(src.index, src.ref, src.offset) elif isinstance(src, (LocationRef, ConstantRef)) and isinstance(dest, LocationRef): - if src.type == dest.type: + if self.get_type(src) == self.get_type(dest): pass - elif isinstance(src.type, RoutineType) and isinstance(dest.type, VectorType): - self.assert_affected_within('inputs', src.type, dest.type.of_type) - self.assert_affected_within('outputs', src.type, dest.type.of_type) - self.assert_affected_within('trashes', src.type, dest.type.of_type) + elif isinstance(self.get_type(src), RoutineType) and isinstance(self.get_type(dest), VectorType): + self.assert_affected_within('inputs', self.get_type(src), self.get_type(dest).of_type) + self.assert_affected_within('outputs', self.get_type(src), self.get_type(dest).of_type) + self.assert_affected_within('trashes', self.get_type(src), self.get_type(dest).of_type) else: raise TypeMismatchError(instr, (src, dest)) else: @@ -789,7 +820,7 @@ class Analyzer(object): raise NotImplementedError(opcode) def analyze_call(self, instr, context): - type = instr.location.type + type = self.get_type(instr.location) if not isinstance(type, (RoutineType, VectorType)): raise TypeMismatchError(instr, instr.location) if isinstance(type, VectorType): @@ -805,7 +836,7 @@ class Analyzer(object): def analyze_goto(self, instr, context): location = instr.location - type_ = location.type + type_ = self.get_type(instr.location) if not isinstance(type_, (RoutineType, VectorType)): raise TypeMismatchError(instr, location) @@ -818,7 +849,7 @@ class Analyzer(object): # and that this routine's trashes and output constraints are a # superset of the called routine's - current_type = self.current_routine.location.type + current_type = self.get_type_for_name(self.current_routine.name) self.assert_affected_within('outputs', type_, current_type) self.assert_affected_within('trashes', type_, current_type) @@ -969,9 +1000,9 @@ class Analyzer(object): context.set_unmeaningful(REG_A) def analyze_point_into(self, instr, context): - if not isinstance(instr.pointer.type, PointerType): + if not isinstance(self.get_type(instr.pointer), PointerType): raise TypeMismatchError(instr, instr.pointer) - if not TableType.is_a_table_type(instr.table.type, TYPE_BYTE): + if not TableType.is_a_table_type(self.get_type(instr.table), TYPE_BYTE): raise TypeMismatchError(instr, instr.table) # check that pointer is not yet associated with any table. diff --git a/src/sixtypical/ast.py b/src/sixtypical/ast.py index bc9a5ee..fc5f96f 100644 --- a/src/sixtypical/ast.py +++ b/src/sixtypical/ast.py @@ -54,11 +54,11 @@ class Program(AST): class Defn(AST): - value_attrs = ('name', 'addr', 'initial', 'location',) + value_attrs = ('name', 'addr', 'initial',) class Routine(AST): - value_attrs = ('name', 'addr', 'initial', 'location',) + value_attrs = ('name', 'addr', 'initial',) children_attrs = ('statics',) child_attrs = ('block',) diff --git a/src/sixtypical/compiler.py b/src/sixtypical/compiler.py index 0d0d862..e033cb1 100644 --- a/src/sixtypical/compiler.py +++ b/src/sixtypical/compiler.py @@ -30,7 +30,8 @@ class UnsupportedOpcodeError(KeyError): class Compiler(object): - def __init__(self, emitter): + def __init__(self, symtab, emitter): + self.symtab = symtab self.emitter = emitter self.routines = {} # routine.name -> Routine self.routine_statics = {} # routine.name -> { static.name -> Label } @@ -38,7 +39,19 @@ class Compiler(object): self.trampolines = {} # Location -> Label self.current_routine = None - # helper methods + # - - - - helper methods - - - - + + def get_type_for_name(self, name): + if self.current_routine and self.symtab.has_static(self.current_routine.name, name): + return self.symtab.fetch_static_type(self.current_routine.name, name) + return self.symtab.fetch_global_type(name) + + def get_type(self, ref): + if isinstance(ref, ConstantRef): + return ref.type + if not isinstance(ref, LocationRef): + raise NotImplementedError + return self.get_type_for_name(ref.name) def addressing_mode_for_index(self, index): if index == REG_X: @@ -50,7 +63,7 @@ class Compiler(object): def compute_length_of_defn(self, defn): length = None - type_ = defn.location.type + type_ = self.get_type_for_name(defn.name) if type_ == TYPE_BYTE: length = 1 elif type_ == TYPE_WORD or isinstance(type_, (PointerType, VectorType)): @@ -74,18 +87,18 @@ class Compiler(object): else: return Absolute(label) - # visitor methods + # - - - - visitor methods - - - - def compile_program(self, program, compilation_roster=None): assert isinstance(program, Program) - defn_labels = [] + declarations = [] for defn in program.defns: length = self.compute_length_of_defn(defn) label = Label(defn.name, addr=defn.addr, length=length) self.labels[defn.name] = label - defn_labels.append((defn, label)) + declarations.append((defn, self.symtab.fetch_global_type(defn.name), label)) for routine in program.routines: self.routines[routine.name] = routine @@ -95,13 +108,15 @@ class Compiler(object): self.labels[routine.name] = label if hasattr(routine, 'statics'): + self.current_routine = routine static_labels = {} for defn in routine.statics: length = self.compute_length_of_defn(defn) label = Label(defn.name, addr=defn.addr, length=length) static_labels[defn.name] = label - defn_labels.append((defn, label)) + declarations.append((defn, self.symtab.fetch_static_type(routine.name, defn.name), label)) self.routine_statics[routine.name] = static_labels + self.current_routine = None if compilation_roster is None: compilation_roster = [['main']] + [[routine.name] for routine in program.routines if routine.name != 'main'] @@ -118,10 +133,9 @@ class Compiler(object): self.emitter.emit(RTS()) # initialized data - for defn, label in defn_labels: + for defn, type_, label in declarations: if defn.initial is not None: initial_data = None - type_ = defn.location.type if type_ == TYPE_BYTE: initial_data = Byte(defn.initial) elif type_ == TYPE_WORD: @@ -137,7 +151,7 @@ class Compiler(object): self.emitter.emit(initial_data) # uninitialized, "BSS" data - for defn, label in defn_labels: + for defn, type_, label in declarations: if defn.initial is None and defn.addr is None: self.emitter.resolve_bss_label(label) @@ -199,7 +213,7 @@ class Compiler(object): self.emitter.emit(LDA(AbsoluteX(Offset(self.get_label(src.ref.name), src.offset.value)))) elif isinstance(src, IndexedRef) and src.index == REG_Y: self.emitter.emit(LDA(AbsoluteY(Offset(self.get_label(src.ref.name), src.offset.value)))) - elif isinstance(src, IndirectRef) and isinstance(src.ref.type, PointerType): + elif isinstance(src, IndirectRef) and isinstance(self.get_type(src.ref), PointerType): self.emitter.emit(LDA(IndirectY(self.get_label(src.ref.name)))) else: self.emitter.emit(LDA(self.absolute_or_zero_page(self.get_label(src.name)))) @@ -241,7 +255,7 @@ class Compiler(object): REG_Y: AbsoluteY, }[dest.index] operand = mode_cls(Offset(self.get_label(dest.ref.name), dest.offset.value)) - elif isinstance(dest, IndirectRef) and isinstance(dest.ref.type, PointerType): + elif isinstance(dest, IndirectRef) and isinstance(self.get_type(dest.ref), PointerType): operand = IndirectY(self.get_label(dest.ref.name)) else: operand = self.absolute_or_zero_page(self.get_label(dest.name)) @@ -260,7 +274,7 @@ class Compiler(object): self.emitter.emit(ADC(mode(Offset(self.get_label(src.ref.name), src.offset.value)))) else: self.emitter.emit(ADC(Absolute(self.get_label(src.name)))) - elif isinstance(dest, LocationRef) and src.type == TYPE_BYTE and dest.type == TYPE_BYTE: + elif isinstance(dest, LocationRef) and self.get_type(src) == TYPE_BYTE and self.get_type(dest) == TYPE_BYTE: if isinstance(src, ConstantRef): dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) @@ -274,7 +288,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) else: raise UnsupportedOpcodeError(instr) - elif isinstance(dest, LocationRef) and src.type == TYPE_WORD and dest.type == TYPE_WORD: + elif isinstance(dest, LocationRef) and self.get_type(src) == TYPE_WORD and self.get_type(dest) == TYPE_WORD: if isinstance(src, ConstantRef): dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) @@ -294,7 +308,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) else: raise UnsupportedOpcodeError(instr) - elif isinstance(dest, LocationRef) and src.type == TYPE_WORD and isinstance(dest.type, PointerType): + elif isinstance(dest, LocationRef) and self.get_type(src) == TYPE_WORD and isinstance(self.get_type(dest), PointerType): if isinstance(src, ConstantRef): dest_label = self.get_label(dest.name) self.emitter.emit(LDA(ZeroPage(dest_label))) @@ -327,7 +341,7 @@ class Compiler(object): self.emitter.emit(SBC(mode(Offset(self.get_label(src.ref.name), src.offset.value)))) else: self.emitter.emit(SBC(Absolute(self.get_label(src.name)))) - elif isinstance(dest, LocationRef) and src.type == TYPE_BYTE and dest.type == TYPE_BYTE: + elif isinstance(dest, LocationRef) and self.get_type(src) == TYPE_BYTE and self.get_type(dest) == TYPE_BYTE: if isinstance(src, ConstantRef): dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) @@ -341,7 +355,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) else: raise UnsupportedOpcodeError(instr) - elif isinstance(dest, LocationRef) and src.type == TYPE_WORD and dest.type == TYPE_WORD: + elif isinstance(dest, LocationRef) and self.get_type(src) == TYPE_WORD and self.get_type(dest) == TYPE_WORD: if isinstance(src, ConstantRef): dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) @@ -409,15 +423,16 @@ class Compiler(object): def compile_call(self, instr): location = instr.location label = self.get_label(instr.location.name) - if isinstance(location.type, RoutineType): + location_type = self.get_type(location) + if isinstance(location_type, RoutineType): self.emitter.emit(JSR(Absolute(label))) - elif isinstance(location.type, VectorType): + elif isinstance(location_type, VectorType): trampoline = self.trampolines.setdefault( location, Label(location.name + '_trampoline') ) self.emitter.emit(JSR(Absolute(trampoline))) else: - raise NotImplementedError + raise NotImplementedError(location_type) def compile_goto(self, instr): self.final_goto_seen = True @@ -426,16 +441,17 @@ class Compiler(object): else: location = instr.location label = self.get_label(instr.location.name) - if isinstance(location.type, RoutineType): + location_type = self.get_type(location) + if isinstance(location_type, RoutineType): self.emitter.emit(JMP(Absolute(label))) - elif isinstance(location.type, VectorType): + elif isinstance(location_type, VectorType): self.emitter.emit(JMP(Indirect(label))) else: - raise NotImplementedError + raise NotImplementedError(location_type) def compile_cmp(self, instr, src, dest): """`instr` is only for reporting purposes""" - if isinstance(src, LocationRef) and src.type == TYPE_WORD: + if isinstance(src, LocationRef) and self.get_type(src) == TYPE_WORD: src_label = self.get_label(src.name) dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) @@ -446,7 +462,7 @@ class Compiler(object): self.emitter.emit(CMP(Absolute(Offset(src_label, 1)))) self.emitter.resolve_label(end_label) return - if isinstance(src, ConstantRef) and src.type == TYPE_WORD: + if isinstance(src, ConstantRef) and self.get_type(src) == TYPE_WORD: dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(dest_label))) self.emitter.emit(CMP(Immediate(Byte(src.low_byte())))) @@ -497,30 +513,41 @@ class Compiler(object): self.emitter.emit(DEC(Absolute(self.get_label(dest.name)))) def compile_copy(self, instr, src, dest): - if isinstance(src, ConstantRef) and isinstance(dest, IndirectRef) and src.type == TYPE_BYTE and isinstance(dest.ref.type, PointerType): + + if isinstance(src, (IndirectRef, IndexedRef)): + src_ref_type = self.get_type(src.ref) + else: + src_type = self.get_type(src) + + if isinstance(dest, (IndirectRef, IndexedRef)): + dest_ref_type = self.get_type(dest.ref) + else: + dest_type = self.get_type(dest) + + if isinstance(src, ConstantRef) and isinstance(dest, IndirectRef) and src_type == TYPE_BYTE and isinstance(dest_ref_type, PointerType): ### copy 123, [ptr] + y dest_label = self.get_label(dest.ref.name) self.emitter.emit(LDA(Immediate(Byte(src.value)))) self.emitter.emit(STA(IndirectY(dest_label))) - elif isinstance(src, LocationRef) and isinstance(dest, IndirectRef) and src.type == TYPE_BYTE and isinstance(dest.ref.type, PointerType): + elif isinstance(src, LocationRef) and isinstance(dest, IndirectRef) and src_type == TYPE_BYTE and isinstance(dest_ref_type, PointerType): ### copy b, [ptr] + y src_label = self.get_label(src.name) dest_label = self.get_label(dest.ref.name) self.emitter.emit(LDA(Absolute(src_label))) self.emitter.emit(STA(IndirectY(dest_label))) - elif isinstance(src, IndirectRef) and isinstance(dest, LocationRef) and dest.type == TYPE_BYTE and isinstance(src.ref.type, PointerType): + elif isinstance(src, IndirectRef) and isinstance(dest, LocationRef) and dest_type == TYPE_BYTE and isinstance(src_ref_type, PointerType): ### copy [ptr] + y, b src_label = self.get_label(src.ref.name) dest_label = self.get_label(dest.name) self.emitter.emit(LDA(IndirectY(src_label))) self.emitter.emit(STA(Absolute(dest_label))) - elif isinstance(src, IndirectRef) and isinstance(dest, IndirectRef) and isinstance(src.ref.type, PointerType) and isinstance(dest.ref.type, PointerType): + elif isinstance(src, IndirectRef) and isinstance(dest, IndirectRef) and isinstance(src_ref_type, PointerType) and isinstance(dest_ref_type, PointerType): ### copy [ptra] + y, [ptrb] + y src_label = self.get_label(src.ref.name) dest_label = self.get_label(dest.ref.name) self.emitter.emit(LDA(IndirectY(src_label))) self.emitter.emit(STA(IndirectY(dest_label))) - elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and src.type == TYPE_WORD and TableType.is_a_table_type(dest.ref.type, TYPE_WORD): + elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and src_type == TYPE_WORD and TableType.is_a_table_type(dest_ref_type, TYPE_WORD): ### copy w, wtab + y src_label = self.get_label(src.name) dest_label = self.get_label(dest.ref.name) @@ -529,7 +556,7 @@ class Compiler(object): self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value)))) self.emitter.emit(LDA(Absolute(Offset(src_label, 1)))) self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value + 256)))) - elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and isinstance(src.type, VectorType) and isinstance(dest.ref.type, TableType) and isinstance(dest.ref.type.of_type, VectorType): + elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and isinstance(src_type, VectorType) and isinstance(dest_ref_type, TableType) and isinstance(dest_ref_type.of_type, VectorType): ### copy vec, vtab + y # FIXME this is the exact same as above - can this be simplified? src_label = self.get_label(src.name) @@ -539,7 +566,7 @@ class Compiler(object): self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value)))) self.emitter.emit(LDA(Absolute(Offset(src_label, 1)))) self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value + 256)))) - elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and isinstance(src.type, RoutineType) and isinstance(dest.ref.type, TableType) and isinstance(dest.ref.type.of_type, VectorType): + elif isinstance(src, LocationRef) and isinstance(dest, IndexedRef) and isinstance(src_type, RoutineType) and isinstance(dest_ref_type, TableType) and isinstance(dest_ref_type.of_type, VectorType): ### copy routine, vtab + y src_label = self.get_label(src.name) dest_label = self.get_label(dest.ref.name) @@ -548,7 +575,7 @@ class Compiler(object): self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value)))) self.emitter.emit(LDA(Immediate(LowAddressByte(src_label)))) self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value + 256)))) - elif isinstance(src, ConstantRef) and isinstance(dest, IndexedRef) and src.type == TYPE_WORD and TableType.is_a_table_type(dest.ref.type, TYPE_WORD): + elif isinstance(src, ConstantRef) and isinstance(dest, IndexedRef) and src_type == TYPE_WORD and TableType.is_a_table_type(dest_ref_type, TYPE_WORD): ### copy 9999, wtab + y dest_label = self.get_label(dest.ref.name) mode = self.addressing_mode_for_index(dest.index) @@ -556,7 +583,7 @@ class Compiler(object): self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value)))) self.emitter.emit(LDA(Immediate(Byte(src.high_byte())))) self.emitter.emit(STA(mode(Offset(dest_label, dest.offset.value + 256)))) - elif isinstance(src, IndexedRef) and isinstance(dest, LocationRef) and TableType.is_a_table_type(src.ref.type, TYPE_WORD) and dest.type == TYPE_WORD: + elif isinstance(src, IndexedRef) and isinstance(dest, LocationRef) and TableType.is_a_table_type(src_ref_type, TYPE_WORD) and dest_type == TYPE_WORD: ### copy wtab + y, w src_label = self.get_label(src.ref.name) dest_label = self.get_label(dest.name) @@ -565,7 +592,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) self.emitter.emit(LDA(mode(Offset(src_label, src.offset.value + 256)))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) - elif isinstance(src, IndexedRef) and isinstance(dest, LocationRef) and isinstance(dest.type, VectorType) and isinstance(src.ref.type, TableType) and isinstance(src.ref.type.of_type, VectorType): + elif isinstance(src, IndexedRef) and isinstance(dest, LocationRef) and isinstance(dest_type, VectorType) and isinstance(src_ref_type, TableType) and isinstance(src_ref_type.of_type, VectorType): ### copy vtab + y, vec # FIXME this is the exact same as above - can this be simplified? src_label = self.get_label(src.ref.name) @@ -575,20 +602,20 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) self.emitter.emit(LDA(mode(Offset(src_label, src.offset.value + 256)))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) - elif src.type == TYPE_BYTE and dest.type == TYPE_BYTE and not isinstance(src, ConstantRef): + elif src_type == TYPE_BYTE and dest_type == TYPE_BYTE and not isinstance(src, ConstantRef): ### copy b1, b2 src_label = self.get_label(src.name) dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Absolute(src_label))) self.emitter.emit(STA(Absolute(dest_label))) - elif src.type == TYPE_WORD and dest.type == TYPE_WORD and isinstance(src, ConstantRef): + elif src_type == TYPE_WORD and dest_type == TYPE_WORD and isinstance(src, ConstantRef): ### copy 9999, w dest_label = self.get_label(dest.name) self.emitter.emit(LDA(Immediate(Byte(src.low_byte())))) self.emitter.emit(STA(Absolute(dest_label))) self.emitter.emit(LDA(Immediate(Byte(src.high_byte())))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) - elif src.type == TYPE_WORD and dest.type == TYPE_WORD and not isinstance(src, ConstantRef): + elif src_type == TYPE_WORD and dest_type == TYPE_WORD and not isinstance(src, ConstantRef): ### copy w1, w2 src_label = self.get_label(src.name) dest_label = self.get_label(dest.name) @@ -596,7 +623,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) self.emitter.emit(LDA(Absolute(Offset(src_label, 1)))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) - elif isinstance(src.type, VectorType) and isinstance(dest.type, VectorType): + elif isinstance(src_type, VectorType) and isinstance(dest_type, VectorType): ### copy v1, v2 src_label = self.get_label(src.name) dest_label = self.get_label(dest.name) @@ -604,7 +631,7 @@ class Compiler(object): self.emitter.emit(STA(Absolute(dest_label))) self.emitter.emit(LDA(Absolute(Offset(src_label, 1)))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) - elif isinstance(src.type, RoutineType) and isinstance(dest.type, VectorType): + elif isinstance(src_type, RoutineType) and isinstance(dest_type, VectorType): ### copy routine, vec src_label = self.get_label(src.name) dest_label = self.get_label(dest.name) @@ -613,7 +640,7 @@ class Compiler(object): self.emitter.emit(LDA(Immediate(LowAddressByte(src_label)))) self.emitter.emit(STA(Absolute(Offset(dest_label, 1)))) else: - raise NotImplementedError(src.type) + raise NotImplementedError(src_type) def compile_if(self, instr): cls = { diff --git a/src/sixtypical/fallthru.py b/src/sixtypical/fallthru.py index 995fbcf..66d95e2 100644 --- a/src/sixtypical/fallthru.py +++ b/src/sixtypical/fallthru.py @@ -7,7 +7,8 @@ from sixtypical.model import RoutineType class FallthruAnalyzer(object): - def __init__(self, debug=False): + def __init__(self, symtab, debug=False): + self.symtab = symtab self.debug = debug def analyze_program(self, program): @@ -16,7 +17,7 @@ class FallthruAnalyzer(object): self.fallthru_map = {} for routine in program.routines: encountered_gotos = list(routine.encountered_gotos) - if len(encountered_gotos) == 1 and isinstance(encountered_gotos[0].type, RoutineType): + if len(encountered_gotos) == 1 and isinstance(self.symtab.fetch_global_type(encountered_gotos[0].name), RoutineType): self.fallthru_map[routine.name] = encountered_gotos[0].name else: self.fallthru_map[routine.name] = None diff --git a/src/sixtypical/model.py b/src/sixtypical/model.py index f89340d..6b8724d 100644 --- a/src/sixtypical/model.py +++ b/src/sixtypical/model.py @@ -57,6 +57,8 @@ class RoutineType(Type): class VectorType(Type): """This memory location contains the address of some other type (currently, only RoutineType).""" + max_range = (0, 0) + def __init__(self, of_type): self.of_type = of_type @@ -92,6 +94,8 @@ class TableType(Type): class PointerType(Type): + max_range = (0, 0) + def __init__(self): self.name = 'pointer' @@ -100,48 +104,24 @@ class PointerType(Type): class Ref(object): - def is_constant(self): - """read-only means that the program cannot change the value - of a location. constant means that the value of the location - will not change during the lifetime of the program.""" - raise NotImplementedError("class {} must implement is_constant()".format(self.__class__.__name__)) - - def max_range(self): - raise NotImplementedError("class {} must implement max_range()".format(self.__class__.__name__)) + pass class LocationRef(Ref): def __init__(self, type, name): - self.type = type self.name = name def __eq__(self, other): - # Ordinarily there will only be one ref with a given name, - # but because we store the type in here and we want to treat - # these objects as immutable, we compare the types, too, - # just to be sure. - equal = isinstance(other, self.__class__) and other.name == self.name - if equal: - assert other.type == self.type, repr((self, other)) - return equal + return self.__class__ is other.__class__ and self.name == other.name def __hash__(self): - return hash(self.name + repr(self.type)) + return hash(self.name) def __repr__(self): - return '%s(%r, %r)' % (self.__class__.__name__, self.type, self.name) + return '%s(%r)' % (self.__class__.__name__, self.name) def __str__(self): - return "{}:{}".format(self.name, self.type) - - def is_constant(self): - return isinstance(self.type, RoutineType) - - def max_range(self): - try: - return self.type.max_range - except: - return (0, 0) + return self.name @classmethod def format_set(cls, location_refs): @@ -165,9 +145,6 @@ class IndirectRef(Ref): def name(self): return '[{}]+y'.format(self.ref.name) - def is_constant(self): - return False - class IndexedRef(Ref): def __init__(self, ref, offset, index): @@ -188,9 +165,6 @@ class IndexedRef(Ref): def name(self): return '{}+{}+{}'.format(self.ref.name, self.offset, self.index.name) - def is_constant(self): - return False - class ConstantRef(Ref): def __init__(self, type, value): @@ -208,12 +182,6 @@ class ConstantRef(Ref): def __repr__(self): return '%s(%r, %r)' % (self.__class__.__name__, self.type, self.value) - def is_constant(self): - return True - - def max_range(self): - return (self.value, self.value) - def high_byte(self): return (self.value >> 8) & 255 diff --git a/src/sixtypical/parser.py b/src/sixtypical/parser.py index c23ec54..c5034f0 100644 --- a/src/sixtypical/parser.py +++ b/src/sixtypical/parser.py @@ -28,12 +28,12 @@ class ForwardReference(object): return "%s(%r)" % (self.__class__.__name__, self.name) -class ParsingContext(object): +class SymbolTable(object): def __init__(self): - self.symbols = {} # token -> SymEntry - self.statics = {} # token -> SymEntry - self.typedefs = {} # token -> Type AST - self.consts = {} # token -> Loc + self.symbols = {} # symbol name -> SymEntry + self.statics = {} # routine name -> (symbol name -> SymEntry) + self.typedefs = {} # type name -> Type AST + self.consts = {} # const name -> ConstantRef for name in ('a', 'x', 'y'): self.symbols[name] = SymEntry(None, TYPE_BYTE) @@ -43,38 +43,55 @@ class ParsingContext(object): def __str__(self): return "Symbols: {}\nStatics: {}\nTypedefs: {}\nConsts: {}".format(self.symbols, self.statics, self.typedefs, self.consts) - def fetch_ref(self, name): - if name in self.statics: - return LocationRef(self.statics[name].type_, name) + def has_static(self, routine_name, name): + return name in self.statics.get(routine_name, {}) + + def fetch_global_type(self, name): + return self.symbols[name].type_ + + def fetch_static_type(self, routine_name, name): + return self.statics[routine_name][name].type_ + + def fetch_global_ref(self, name): if name in self.symbols: return LocationRef(self.symbols[name].type_, name) return None + def fetch_static_ref(self, routine_name, name): + routine_statics = self.statics.get(routine_name, {}) + if name in routine_statics: + return LocationRef(routine_statics[name].type_, name) + return None + class Parser(object): - def __init__(self, context, text, filename): - self.context = context + def __init__(self, symtab, text, filename): + self.symtab = symtab self.scanner = Scanner(text, filename) + self.current_routine_name = None def syntax_error(self, msg): self.scanner.syntax_error(msg) - def lookup(self, name): - model = self.context.fetch_ref(name) + def lookup(self, name, allow_forward=False, routine_name=None): + model = self.symtab.fetch_global_ref(name) + if model is None and routine_name: + model = self.symtab.fetch_static_ref(routine_name, name) + if model is None and allow_forward: + return ForwardReference(name) if model is None: self.syntax_error('Undefined symbol "{}"'.format(name)) return model - def declare(self, name, ast_node, type_, static=False): - if self.context.fetch_ref(name): + def declare(self, name, ast_node, type_): + if self.symtab.fetch_global_ref(name): self.syntax_error('Symbol "%s" already declared' % name) - if static: - self.context.statics[name] = SymEntry(ast_node, type_) - else: - self.context.symbols[name] = SymEntry(ast_node, type_) + self.symtab.symbols[name] = SymEntry(ast_node, type_) - def clear_statics(self): - self.context.statics = {} + def declare_static(self, routine_name, name, ast_node, type_): + if self.symtab.fetch_global_ref(name): + self.syntax_error('Symbol "%s" already declared' % name) + self.symtab.statics.setdefault(routine_name, {})[name] = SymEntry(ast_node, type_) # ---- symbol resolution @@ -95,10 +112,8 @@ class Parser(object): type_.outputs = set([resolve(w) for w in type_.outputs]) type_.trashes = set([resolve(w) for w in type_.trashes]) - for defn in program.defns: - backpatch_constraint_labels(defn.location.type) - for routine in program.routines: - backpatch_constraint_labels(routine.location.type) + for name, symentry in self.symtab.symbols.items(): + backpatch_constraint_labels(symentry.type_) def resolve_fwd_reference(obj, field): field_value = getattr(obj, field, None) @@ -110,7 +125,6 @@ class Parser(object): for node in program.all_children(): if isinstance(node, SingleOp): - resolve_fwd_reference(node, 'location') resolve_fwd_reference(node, 'src') resolve_fwd_reference(node, 'dest') if isinstance(node, (Call, GoTo)): @@ -127,7 +141,7 @@ class Parser(object): if self.scanner.on('const'): self.defn_const() typenames = ['byte', 'word', 'table', 'vector', 'pointer'] # 'routine', - typenames.extend(self.context.typedefs.keys()) + typenames.extend(self.symtab.typedefs.keys()) while self.scanner.on(*typenames): type_, defn = self.defn() self.declare(defn.name, defn, type_) @@ -135,9 +149,11 @@ class Parser(object): while self.scanner.consume('define'): name = self.scanner.token self.scanner.scan() + self.current_routine_name = name type_, routine = self.routine(name) self.declare(name, routine, type_) routines.append(routine) + self.current_routine_name = None self.scanner.check_type('EOF') program = Program(self.scanner.line_number, defns=defns, routines=routines) @@ -148,18 +164,18 @@ class Parser(object): self.scanner.expect('typedef') type_ = self.defn_type() name = self.defn_name() - if name in self.context.typedefs: + if name in self.symtab.typedefs: self.syntax_error('Type "%s" already declared' % name) - self.context.typedefs[name] = type_ + self.symtab.typedefs[name] = type_ return type_ def defn_const(self): self.scanner.expect('const') name = self.defn_name() - if name in self.context.consts: + if name in self.symtab.consts: self.syntax_error('Const "%s" already declared' % name) loc = self.const() - self.context.consts[name] = loc + self.symtab.consts[name] = loc return loc def defn(self): @@ -189,9 +205,7 @@ class Parser(object): if initial is not None and addr is not None: self.syntax_error("Definition cannot have both initial value and explicit address") - location = LocationRef(type_, name) - - return type_, Defn(self.scanner.line_number, name=name, addr=addr, initial=initial, location=location) + return type_, Defn(self.scanner.line_number, name=name, addr=addr, initial=initial) def const(self): if self.scanner.token in ('on', 'off'): @@ -208,8 +222,8 @@ class Parser(object): loc = ConstantRef(TYPE_WORD, int(self.scanner.token)) self.scanner.scan() return loc - elif self.scanner.token in self.context.consts: - loc = self.context.consts[self.scanner.token] + elif self.scanner.token in self.symtab.consts: + loc = self.symtab.consts[self.scanner.token] self.scanner.scan() return loc else: @@ -257,9 +271,9 @@ class Parser(object): else: type_name = self.scanner.token self.scanner.scan() - if type_name not in self.context.typedefs: + if type_name not in self.symtab.typedefs: self.syntax_error("Undefined type '%s'" % type_name) - type_ = self.context.typedefs[type_name] + type_ = self.symtab.typedefs[type_name] return type_ @@ -297,20 +311,9 @@ class Parser(object): self.scanner.scan() else: statics = self.statics() - - self.clear_statics() - for defn in statics: - self.declare(defn.name, defn, defn.location.type, static=True) block = self.block() - self.clear_statics() - addr = None - location = LocationRef(type_, name) - return type_, Routine( - self.scanner.line_number, - name=name, block=block, addr=addr, - location=location, statics=statics, - ) + return type_, Routine(self.scanner.line_number, name=name, block=block, addr=addr, statics=statics) def labels(self): accum = [] @@ -334,16 +337,12 @@ class Parser(object): return accum def locexpr(self): - if self.scanner.token in ('on', 'off', 'word') or self.scanner.token in self.context.consts or self.scanner.on_type('integer literal'): + if self.scanner.token in ('on', 'off', 'word') or self.scanner.token in self.symtab.consts or self.scanner.on_type('integer literal'): return self.const() else: name = self.scanner.token self.scanner.scan() - loc = self.context.fetch_ref(name) - if loc: - return loc - else: - return ForwardReference(name) + return self.lookup(name, allow_forward=True, routine_name=self.current_routine_name) def indlocexpr(self): if self.scanner.consume('['): @@ -361,7 +360,7 @@ class Parser(object): index = None offset = ConstantRef(TYPE_BYTE, 0) if self.scanner.consume('+'): - if self.scanner.token in self.context.consts or self.scanner.on_type('integer literal'): + if self.scanner.token in self.symtab.consts or self.scanner.on_type('integer literal'): offset = self.const() self.scanner.expect('+') index = self.locexpr() @@ -374,6 +373,7 @@ class Parser(object): type_, defn = self.defn() if defn.initial is None: self.syntax_error("Static definition {} must have initial value".format(defn)) + self.declare_static(self.current_routine_name, defn.name, defn, type_) defns.append(defn) return defns