diff --git a/src/sixtypical/ast.py b/src/sixtypical/ast.py index bbcc68a..771fc1d 100644 --- a/src/sixtypical/ast.py +++ b/src/sixtypical/ast.py @@ -1,8 +1,29 @@ # encoding: UTF-8 class AST(object): + children_attrs = () + child_attrs = () + value_attrs = () + def __init__(self, **kwargs): - self.attrs = kwargs + self.attrs = {} + for attr in self.children_attrs: + self.attrs[attr] = kwargs.pop(attr, []) + for child in self.attrs[attr]: + assert child is None or isinstance(child, AST), \ + "child %s=%r of %r is not an AST node" % (attr, child, self) + for attr in self.child_attrs: + self.attrs[attr] = kwargs.pop(attr, None) + child = self.attrs[attr] + assert child is None or isinstance(child, AST), \ + "child %s=%r of %r is not an AST node" % (attr, child, self) + for attr in self.value_attrs: + self.attrs[attr] = kwargs.pop(attr, None) + assert (not kwargs), "extra arguments supplied to {} node: {}".format(self.type, kwargs) + + @property + def type(self): + return self.__class__.__name__ def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.attrs) @@ -12,22 +33,50 @@ class AST(object): return self.attrs[name] raise AttributeError(name) + def all_children(self): + for attr in self.children_attrs: + for child in self.attrs[attr]: + yield child + for subchild in child.all_children(): + yield subchild + for attr in self.child_attrs: + child = self.attrs[attr] + yield child + for subchild in child.all_children(): + yield subchild + class Program(AST): - pass + children_attrs = ('defns', 'routines',) class Defn(AST): - pass + value_attrs = ('name', 'addr', 'initial', 'location',) class Routine(AST): - pass + value_attrs = ('name', 'addr', 'initial', 'location',) + children_attrs = ('statics',) + child_attrs = ('block',) class Block(AST): - pass + children_attrs = ('instrs',) class Instr(AST): pass + + +class SingleOp(Instr): + value_attrs = ('opcode', 'dest', 'src', 'index', 'location',) + + +class BlockOp(Instr): + value_attrs = ('opcode', 'dest', 'src', 'inverted') + child_attrs = ('block',) + + +class IfOp(Instr): + value_attrs = ('opcode', 'dest', 'src', 'inverted') + child_attrs = ('block1', 'block2',) diff --git a/src/sixtypical/parser.py b/src/sixtypical/parser.py index 480953c..8fd1059 100644 --- a/src/sixtypical/parser.py +++ b/src/sixtypical/parser.py @@ -1,6 +1,6 @@ # encoding: UTF-8 -from sixtypical.ast import Program, Defn, Routine, Block, Instr +from sixtypical.ast import Program, Defn, Routine, Block, SingleOp, BlockOp, IfOp from sixtypical.model import ( TYPE_BIT, TYPE_BYTE, TYPE_WORD, RoutineType, VectorType, TableType, BufferType, PointerType, @@ -352,8 +352,8 @@ class Parser(object): block2 = None if self.scanner.consume('else'): block2 = self.block() - return Instr(opcode='if', dest=None, src=src, - block1=block1, block2=block2, inverted=inverted) + return IfOp(opcode='if', dest=None, src=src, + block1=block1, block2=block2, inverted=inverted) elif self.scanner.consume('repeat'): inverted = False src = None @@ -364,7 +364,7 @@ class Parser(object): src = self.locexpr() else: self.scanner.expect('forever') - return Instr(opcode='repeat', dest=None, src=src, + return BlockOp(opcode='repeat', dest=None, src=src, block=block, inverted=inverted) elif self.scanner.token in ("ld",): # the same as add, sub, cmp etc below, except supports an indlocexpr for the src @@ -373,7 +373,7 @@ class Parser(object): dest = self.locexpr() self.scanner.expect(',') src = self.indlocexpr() - return Instr(opcode=opcode, dest=dest, src=src, index=None) + return SingleOp(opcode=opcode, dest=dest, src=src, index=None) elif self.scanner.token in ("add", "sub", "cmp", "and", "or", "xor"): opcode = self.scanner.token self.scanner.scan() @@ -383,25 +383,25 @@ class Parser(object): index = None if self.scanner.consume('+'): index = self.locexpr() - return Instr(opcode=opcode, dest=dest, src=src, index=index) + return SingleOp(opcode=opcode, dest=dest, src=src, index=index) elif self.scanner.token in ("st",): opcode = self.scanner.token self.scanner.scan() src = self.locexpr() self.scanner.expect(',') dest = self.indlocexpr() - return Instr(opcode=opcode, dest=dest, src=src, index=None) + return SingleOp(opcode=opcode, dest=dest, src=src, index=None) elif self.scanner.token in ("shl", "shr", "inc", "dec"): opcode = self.scanner.token self.scanner.scan() dest = self.locexpr() - return Instr(opcode=opcode, dest=dest, src=None) + return SingleOp(opcode=opcode, dest=dest, src=None) elif self.scanner.token in ("call", "goto"): opcode = self.scanner.token self.scanner.scan() name = self.scanner.token self.scanner.scan() - instr = Instr(opcode=opcode, location=name, dest=None, src=None) + instr = SingleOp(opcode=opcode, location=name, dest=None, src=None) self.backpatch_instrs.append(instr) return instr elif self.scanner.token in ("copy",): @@ -410,16 +410,16 @@ class Parser(object): src = self.indlocexpr(forward=True) self.scanner.expect(',') dest = self.indlocexpr() - instr = Instr(opcode=opcode, dest=dest, src=src) + instr = SingleOp(opcode=opcode, dest=dest, src=src) self.backpatch_instrs.append(instr) return instr elif self.scanner.consume("with"): self.scanner.expect("interrupts") self.scanner.expect("off") block = self.block() - return Instr(opcode='with-sei', dest=None, src=None, block=block) + return BlockOp(opcode='with-sei', dest=None, src=None, block=block) elif self.scanner.consume("trash"): dest = self.locexpr() - return Instr(opcode='trash', src=None, dest=dest) + return SingleOp(opcode='trash', src=None, dest=dest) else: raise ValueError('bad opcode "%s"' % self.scanner.token)