diff --git a/src/sixtypical/analyzer.py b/src/sixtypical/analyzer.py index d615ca5..80197ec 100644 --- a/src/sixtypical/analyzer.py +++ b/src/sixtypical/analyzer.py @@ -8,10 +8,6 @@ from sixtypical.model import ( ) -UNINITIALIZED = 'UNINITIALIZED' -INITIALIZED = 'INITIALIZED' - - class StaticAnalysisError(ValueError): pass @@ -41,76 +37,89 @@ class TypeMismatchError(StaticAnalysisError): class Context(): + """ + A location is touched if it was changed (or even potentially + changed) during this routine, or some routine called by this routine. + + A location is meaningful if it was an input to this routine, + or if it was set to a meaningful value by some operation in this + routine (or some routine called by this routine). + + A location is writeable if it was listed in the outputs or trashes + lists of this routine. + """ def __init__(self, inputs, outputs, trashes): - self._store = {} # Ref -> INITALIZED/UNINITIALIZED - self._writeables = set() + self._touched = set() + self._meaningful = set() + self._writeable = set() for ref in inputs: - self._store.setdefault(ref, INITIALIZED) + self._meaningful.add(ref) output_names = set() for ref in outputs: output_names.add(ref.name) - self._store.setdefault(ref, UNINITIALIZED) - self._writeables.add(ref.name) + self._writeable.add(ref) for ref in trashes: if ref.name in output_names: raise UsageClashError(ref.name) - self._store.setdefault(ref, UNINITIALIZED) - self._writeables.add(ref.name) + self._writeable.add(ref) def clone(self): c = Context([], [], []) - c._store = dict(self._store) - c._writeables = set(self._writeables) + c._touched = set(self._touched) + c._meaningful = set(self._meaningful) + c._writeable = set(self._writeable) return c def set_from(self, c): - self._store = dict(c._store) - self._writeables = set(c._writeables) + self._touched = set(c._touched) + self._meaningful = set(c._meaningful) + self._writeable = set(c._writeable) - def each_initialized(self): - for key, value in 
self._store.iteritems(): - if value == INITIALIZED: - yield key + def each_meaningful(self): + for ref in self._meaningful: + yield ref - def assert_initialized(self, *refs, **kwargs): + def each_touched(self): + for ref in self._touched: + yield ref + + def assert_meaningful(self, *refs, **kwargs): exception_class = kwargs.get('exception_class', UninitializedAccessError) for ref in refs: if isinstance(ref, ConstantRef): pass elif isinstance(ref, LocationRef): - if self.get(ref) != INITIALIZED: + if ref not in self._meaningful: raise exception_class(ref.name) else: raise ValueError(ref) def assert_writeable(self, *refs): for ref in refs: - if ref.name not in self._writeables: + if ref not in self._writeable: raise IllegalWriteError(ref.name) - def set_initialized(self, *refs): + def set_touched(self, *refs): for ref in refs: - self.set(ref, INITIALIZED) + self._touched.add(ref) - def set_uninitialized(self, *refs): + def set_meaningful(self, *refs): for ref in refs: - self.set(ref, UNINITIALIZED) + self._meaningful.add(ref) - def get(self, ref): - if isinstance(ref, ConstantRef): - return INITIALIZED - elif isinstance(ref, LocationRef): - if ref not in self._store: - return UNINITIALIZED - return self._store[ref] - else: - raise ValueError(ref) - - def set(self, ref, value): - assert isinstance(ref, LocationRef) - self._store[ref] = value + def set_unmeaningful(self, *refs): + for ref in refs: + if ref in self._meaningful: + self._meaningful.remove(ref) + def set_written(self, *refs): + """A "helper" method which does the following common sequence for + the given refs: asserts they're all writable, and sets them all + as touched and meaningful.""" + self.assert_writeable(*refs) + self.set_touched(*refs) + self.set_meaningful(*refs) def analyze_program(program): assert isinstance(program, Program) @@ -127,7 +136,10 @@ def analyze_routine(routine, routines): context = Context(routine.inputs, routine.outputs, routine.trashes) analyze_block(routine.block, context, 
routines) for ref in routine.outputs: - context.assert_initialized(ref, exception_class=UninitializedOutputError) + context.assert_meaningful(ref, exception_class=UninitializedOutputError) + for ref in context.each_touched(): + if ref not in routine.outputs and ref not in routine.trashes: + raise IllegalWriteError(ref.name) def analyze_block(block, context, routines): @@ -150,9 +162,8 @@ def analyze_instr(instr, context, routines): raise TypeMismatchError((src, dest)) elif src.type != dest.type: raise TypeMismatchError((src, dest)) - context.assert_initialized(src) - context.assert_writeable(dest, FLAG_Z, FLAG_N) - context.set_initialized(dest, FLAG_Z, FLAG_N) + context.assert_meaningful(src) + context.set_written(dest, FLAG_Z, FLAG_N) elif opcode == 'st': if instr.index: if src.type == TYPE_BYTE and dest.type == TYPE_BYTE_TABLE: @@ -161,49 +172,44 @@ def analyze_instr(instr, context, routines): raise TypeMismatchError((src, dest)) elif src.type != dest.type: raise TypeMismatchError((src, dest)) - context.assert_initialized(src) - context.assert_writeable(dest) - context.set_initialized(dest) + context.assert_meaningful(src) + context.set_written(dest) elif opcode in ('add', 'sub'): - context.assert_initialized(src, dest, FLAG_C) - context.assert_writeable(dest, FLAG_Z, FLAG_N, FLAG_C, FLAG_V) - context.set_initialized(dest, FLAG_Z, FLAG_N, FLAG_C, FLAG_V) + context.assert_meaningful(src, dest, FLAG_C) + context.set_written(dest, FLAG_Z, FLAG_N, FLAG_C, FLAG_V) elif opcode in ('inc', 'dec'): - context.assert_initialized(dest) - context.assert_writeable(dest, FLAG_Z, FLAG_N) - context.set_initialized(dest, FLAG_Z, FLAG_N) + context.assert_meaningful(dest) + context.set_written(dest, FLAG_Z, FLAG_N) elif opcode == 'cmp': - context.assert_initialized(src, dest) - context.assert_writeable(FLAG_Z, FLAG_N, FLAG_C) - context.set_initialized(FLAG_Z, FLAG_N, FLAG_C) + context.assert_meaningful(src, dest) + context.set_written(FLAG_Z, FLAG_N, FLAG_C) elif opcode in ('and', 
'or', 'xor'): - context.assert_initialized(src, dest) - context.assert_writeable(dest, FLAG_Z, FLAG_N) - context.set_initialized(dest, FLAG_Z, FLAG_N) + context.assert_meaningful(src, dest) + context.set_written(dest, FLAG_Z, FLAG_N) elif opcode in ('shl', 'shr'): - context.assert_initialized(dest, FLAG_C) - context.assert_writeable(dest, FLAG_Z, FLAG_N, FLAG_C) - context.set_initialized(dest, FLAG_Z, FLAG_N, FLAG_C) + context.assert_meaningful(dest, FLAG_C) + context.set_written(dest, FLAG_Z, FLAG_N, FLAG_C) elif opcode == 'call': routine = routines[instr.name] for ref in routine.inputs: - context.assert_initialized(ref) + context.assert_meaningful(ref) for ref in routine.outputs: - context.assert_writeable(ref) - context.set_initialized(ref) + context.set_written(ref) for ref in routine.trashes: context.assert_writeable(ref) - context.set_uninitialized(ref) + context.set_touched(ref) + context.set_unmeaningful(ref) elif opcode == 'if': context1 = context.clone() context2 = context.clone() analyze_block(instr.block1, context1, routines) if instr.block2 is not None: analyze_block(instr.block2, context2, routines) - for ref in context1.each_initialized(): - context2.assert_initialized(ref, exception_class=InconsistentInitializationError) - for ref in context2.each_initialized(): - context1.assert_initialized(ref, exception_class=InconsistentInitializationError) + # TODO: do we need to deal with touched separately here too? 
+ for ref in context1.each_meaningful(): + context2.assert_meaningful(ref, exception_class=InconsistentInitializationError) + for ref in context2.each_meaningful(): + context1.assert_meaningful(ref, exception_class=InconsistentInitializationError) context.set_from(context1) elif opcode == 'repeat': # it will always be executed at least once, so analyze it having @@ -222,9 +228,9 @@ def analyze_instr(instr, context, routines): pass else: raise TypeMismatchError((src, dest)) - context.assert_initialized(src) - context.assert_writeable(dest) - context.set_initialized(dest) - context.set_uninitialized(REG_A, FLAG_Z, FLAG_N) + context.assert_meaningful(src) + context.set_written(dest) + context.set_touched(REG_A, FLAG_Z, FLAG_N) + context.set_unmeaningful(REG_A, FLAG_Z, FLAG_N) else: raise NotImplementedError(opcode)