diff --git a/il65/compile.py b/il65/compile.py index 589cf9a81..aeb747284 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -14,7 +14,7 @@ import attr from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \ SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\ SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \ - check_symbol_definition, UndefinedSymbolError, process_expression + check_symbol_definition, UndefinedSymbolError, process_expression, Label from .plylex import SourceRef, print_bold from .datatypes import DataType, VarType @@ -39,13 +39,12 @@ class PlyParser: self.create_multiassigns(module) self.check_and_merge_zeropages(module) self.process_all_expressions(module) - return module # XXX - # if not self.parsing_import: - # # these shall only be done on the main module after all imports have been done: - # self.apply_directive_options(module) - # self.determine_subroutine_usage(module) - # self.semantic_check(module) - # self.allocate_zeropage_vars(module) + if not self.imported_module: + # the following shall only be done on the main module after all imports have been done: + self.apply_directive_options(module) + self.determine_subroutine_usage(module) + self.semantic_check(module) + self.allocate_zeropage_vars(module) except ParseError as x: self.handle_parse_error(x) if self.parse_errors: @@ -58,7 +57,7 @@ class PlyParser: print_bold("ERROR: {}: {}".format(sourceref, fmtstring.format(*args))) def _check_last_statement_is_return(self, last_stmt: AstNode) -> None: - if isinstance(last_stmt, Subroutine): + if isinstance(last_stmt, (Subroutine, Return, Goto)): return if isinstance(last_stmt, Directive) and last_stmt.name == "noreturn": return @@ -69,52 +68,47 @@ class PlyParser: continue if "jmp " in line or "jmp\t" in line or "rts" in line or "rti" in line: return + print(last_stmt) raise ParseError("last statement in a block/subroutine must be a return or goto, " "(or %noreturn directive to silence this error)", last_stmt.sourceref) - # def semantic_check(self, module: Module) -> None: - # # perform semantic analysis / checks on the syntactic parse tree we have so far - # # (note: symbol names have already been checked to exist when we start this) - # for node, parent in module.all_nodes(): - # previous_stmt = None - # if isinstance(node, SubCall): - # if isinstance(node.target, SymbolName): - # subdef = block.scope.lookup(stmt.target.target.name) - # self.check_subroutine_arguments(stmt, subdef) - # if isinstance(stmt, Subroutine): - # # the previous statement (if any) must be a Goto or Return - # if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)): - # raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", stmt.sourceref) - # if isinstance(previous_stmt, Subroutine): - # # the statement after a subroutine can not be some random executable instruction because it could not be reached - # if not isinstance(stmt, (Subroutine, Label, Directive, InlineAssembly, VarDef)): - # raise ParseError("statement following a subroutine can't be runnable code, " - # "at least use a label first", stmt.sourceref) - # previous_stmt = stmt - # if isinstance(stmt, IncrDecr): - # if isinstance(stmt.target, SymbolName): - # symdef = block.scope.lookup(stmt.target.name) - # if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: - # raise ParseError("cannot modify a constant", stmt.sourceref) - # - # if parent and block.name != "ZP" and not isinstance(stmt, (Return, Goto)): - # self._check_last_statement_is_return(stmt) + def semantic_check(self, module: Module) -> None: + # perform semantic analysis / checks on the syntactic parse tree we have so far + # (note: symbol names have already been checked to exist when we start this) + previous_stmt = None + for node in module.all_nodes(): + if isinstance(node, Scope): + previous_stmt = None + if node.nodes and isinstance(node.parent, (Block, Subroutine)): + self._check_last_statement_is_return(node.nodes[-1]) + elif isinstance(node, SubCall): + if isinstance(node.target, SymbolName): + subdef = node.my_scope().lookup(node.target.name) + self.check_subroutine_arguments(node, subdef) # type: ignore + elif isinstance(node, Subroutine): + # the previous statement (if any) must be a Goto or Return + if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)): + raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", node.sourceref) + elif isinstance(node, IncrDecr): + if isinstance(node.target, SymbolName): + symdef = node.my_scope().lookup(node.target.name) + if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: + raise ParseError("cannot modify a constant", node.sourceref) + previous_stmt = node def check_subroutine_arguments(self, call: SubCall, subdef: Subroutine) -> None: - # @todo must be moved to expression processing, or, restructure whole AST tree walking to make it easier to walk over everything - if len(call.arguments) != len(subdef.param_spec): + if len(call.arguments.nodes) != len(subdef.param_spec): raise ParseError("invalid number of arguments ({:d}, required: {:d})" - .format(len(call.arguments), len(subdef.param_spec)), call.sourceref) - for arg, param in zip(call.arguments, subdef.param_spec): + .format(len(call.arguments.nodes), len(subdef.param_spec)), call.sourceref) + for arg, param in zip(call.arguments.nodes, subdef.param_spec): if arg.name and arg.name != param[0]: raise ParseError("parameter name mismatch", arg.sourceref) def check_and_merge_zeropages(self, module: Module) -> None: # merge all ZP blocks into one - # XXX done: converted to new nodes zeropage = None - for block in module.all_nodes([Block]): - if block.name == "ZP": + for block in module.all_nodes(Block): + if block.name == "ZP": # type: ignore if zeropage: # merge other ZP block into first ZP block for node in block.nodes: @@ -149,12 +143,12 @@ class PlyParser: raise ParseError(str(x), vardef.sourceref) def check_all_symbolnames(self, module: Module) -> None: - for node in module.all_nodes([SymbolName]): - check_symbol_definition(node.name, node.my_scope(), node.sourceref) + for node in module.all_nodes(SymbolName): + check_symbol_definition(node.name, node.my_scope(), node.sourceref) # type: ignore def process_all_expressions(self, module: Module) -> None: # process/simplify all expressions (constant folding etc) - encountered_blocks = set() + encountered_blocks = set() # type: Set[Block] for node in module.all_nodes(): if isinstance(node, Block): parentname = (node.parent.name + ".") if node.parent else "" @@ -174,15 +168,19 @@ class PlyParser: elif isinstance(node, Assignment): lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes) if len(lvalue_types) == 1: - _, node.right = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) + _, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) + if isinstance(newright, (LiteralValue, Expression)): + node.right = newright + else: + raise TypeError("invalid coerced constant type", newright) else: for lv_dt in lvalue_types: coerce_constant_value(lv_dt, node.right, node.sourceref) + @no_type_check def create_multiassigns(self, module: Module) -> None: # create multi-assign statements from nested assignments (A=B=C=5), # and optimize TargetRegisters down to single Register if it's just one register. - # XXX done: converted to new nodes def reduce_right(assign: Assignment) -> Assignment: if isinstance(assign.right, Assignment): right = reduce_right(assign.right) @@ -190,11 +188,12 @@ class PlyParser: assign.right = right.right return assign - for node in module.all_nodes([Assignment]): + for node in module.all_nodes(Assignment): if isinstance(node.right, Assignment): multi = reduce_right(node) assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment) + @no_type_check def apply_directive_options(self, module: Module) -> None: def set_save_registers(scope: Scope, save_dir: Directive) -> None: if not scope: @@ -211,81 +210,77 @@ class PlyParser: else: scope.save_registers = True - for block, parent in module.all_scopes(): - if isinstance(block, Module): + for directive in module.all_nodes(Directive): + node = directive.my_scope().parent + if isinstance(node, Module): # process the module's directives - for directive in block.scope.filter_nodes(Directive): - if directive.name == "output": - if len(directive.args) != 1 or not isinstance(directive.args[0], str): - raise ParseError("expected one str directive argument", directive.sourceref) - if directive.args[0] == "raw": - block.format = ProgramFormat.RAW - block.address = 0xc000 - elif directive.args[0] == "prg": - block.format = ProgramFormat.PRG - block.address = 0xc000 - elif directive.args[0] == "basic": - block.format = ProgramFormat.BASIC - block.address = 0x0801 - else: - raise ParseError("invalid directive args", directive.sourceref) - elif directive.name == "address": - if len(directive.args) != 1 or type(directive.args[0]) is not int: - raise ParseError("expected one integer directive argument", directive.sourceref) - if block.format == ProgramFormat.BASIC: - raise ParseError("basic cannot have a custom load address", directive.sourceref) - block.address = directive.args[0] - attr.validate(block) - elif directive.name in "import": - pass # is processed earlier - elif directive.name == "zp": - if len(directive.args) not in (1, 2) or set(directive.args) - {"clobber", "restore"}: - raise ParseError("invalid directive args", directive.sourceref) - if "clobber" in directive.args and "restore" in directive.args: - module.zp_options = ZpOptions.CLOBBER_RESTORE - elif "clobber" in directive.args: - module.zp_options = ZpOptions.CLOBBER - elif "restore" in directive.args: - raise ParseError("invalid directive args", directive.sourceref) - elif directive.name == "saveregisters": - set_save_registers(block.scope, directive) + if directive.name == "output": + if len(directive.args) != 1 or not isinstance(directive.args[0], str): + raise ParseError("expected one str directive argument", directive.sourceref) + if directive.args[0] == "raw": + node.format = ProgramFormat.RAW + node.address = 0xc000 + elif directive.args[0] == "prg": + node.format = ProgramFormat.PRG + node.address = 0xc000 + elif directive.args[0] == "basic": + node.format = ProgramFormat.BASIC + node.address = 0x0801 else: - raise NotImplementedError(directive.name) - elif isinstance(block, Block): + raise ParseError("invalid directive args", directive.sourceref) + elif directive.name == "address": + if len(directive.args) != 1 or type(directive.args[0]) is not int: + raise ParseError("expected one integer directive argument", directive.sourceref) + if node.format == ProgramFormat.BASIC: + raise ParseError("basic cannot have a custom load address", directive.sourceref) + node.address = directive.args[0] + attr.validate(node) + elif directive.name in "import": + pass # is processed earlier + elif directive.name == "zp": + if len(directive.args) not in (1, 2) or set(directive.args) - {"clobber", "restore"}: + raise ParseError("invalid directive args", directive.sourceref) + if "clobber" in directive.args and "restore" in directive.args: + module.zp_options = ZpOptions.CLOBBER_RESTORE + elif "clobber" in directive.args: + module.zp_options = ZpOptions.CLOBBER + elif "restore" in directive.args: + raise ParseError("invalid directive args", directive.sourceref) + elif directive.name == "saveregisters": + set_save_registers(directive.my_scope(), directive) + else: + raise NotImplementedError(directive.name) + elif isinstance(node, Block): # process the block's directives - for directive in block.scope.filter_nodes(Directive): - if directive.name == "saveregisters": - set_save_registers(block.scope, directive) - elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"): - continue - else: - raise NotImplementedError(directive.name) - elif isinstance(block, Subroutine): - if block.scope: - # process the sub's directives - for directive in block.scope.filter_nodes(Directive): - if directive.name == "saveregisters": - set_save_registers(block.scope, directive) - elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"): - continue - else: - raise NotImplementedError(directive.name) + if directive.name == "saveregisters": + set_save_registers(directive.my_scope(), directive) + elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"): + continue + else: + raise NotImplementedError(directive.name) + elif isinstance(node, Subroutine): + # process the sub's directives + if directive.name == "saveregisters": + set_save_registers(directive.my_scope(), directive) + elif directive.name in ("breakpoint", "asmbinary", "asminclude", "noreturn"): + continue + else: + raise NotImplementedError(directive.name) @no_type_check def determine_subroutine_usage(self, module: Module) -> None: module.subroutine_usage.clear() - for block, parent in module.all_scopes(): - for node in block.nodes: - if isinstance(node, InlineAssembly): - self._get_subroutine_usages_from_asm(module.subroutine_usage, node, block.scope) - elif isinstance(node, SubCall): - self._get_subroutine_usages_from_subcall(module.subroutine_usage, node, block.scope) - elif isinstance(node, Goto): - self._get_subroutine_usages_from_goto(module.subroutine_usage, node, block.scope) - elif isinstance(node, Return): - self._get_subroutine_usages_from_return(module.subroutine_usage, node, block.scope) - elif isinstance(node, Assignment): - self._get_subroutine_usages_from_assignment(module.subroutine_usage, node, block.scope) + for node in module.all_nodes(): + if isinstance(node, InlineAssembly): + self._get_subroutine_usages_from_asm(module.subroutine_usage, node, node.my_scope()) + elif isinstance(node, SubCall): + self._get_subroutine_usages_from_subcall(module.subroutine_usage, node, node.my_scope()) + elif isinstance(node, Goto): + self._get_subroutine_usages_from_goto(module.subroutine_usage, node, node.my_scope()) + elif isinstance(node, Return): + self._get_subroutine_usages_from_return(module.subroutine_usage, node, node.my_scope()) + elif isinstance(node, Assignment): + self._get_subroutine_usages_from_assignment(module.subroutine_usage, node, node.my_scope()) print("----------SUBROUTINES IN USE-------------") # XXX import pprint pprint.pprint(module.subroutine_usage) # XXX @@ -293,10 +288,9 @@ class PlyParser: def _get_subroutine_usages_from_subcall(self, usages: Dict[Tuple[str, str], Set[str]], subcall: SubCall, parent_scope: Scope) -> None: - target = subcall.target.target - if isinstance(target, SymbolName): - usages[(parent_scope.name, target.name)].add(str(subcall.sourceref)) - for arg in subcall.arguments: + if isinstance(subcall.target, SymbolName): + usages[(parent_scope.name, subcall.target.name)].add(str(subcall.sourceref)) + for arg in subcall.arguments.nodes: self._get_subroutine_usages_from_expression(usages, arg.value, parent_scope) def _get_subroutine_usages_from_expression(self, usages: Dict[Tuple[str, str], Set[str]], @@ -324,6 +318,7 @@ class PlyParser: else: raise TypeError("unknown expr type to scan for sub usages", expr, expr.sourceref) + @no_type_check def _get_subroutine_usages_from_goto(self, usages: Dict[Tuple[str, str], Set[str]], goto: Goto, parent_scope: Scope) -> None: target = goto.target.target @@ -369,7 +364,6 @@ class PlyParser: usages[(namespace, symbol.name)].add(str(asmnode.sourceref)) def check_directives(self, module: Module) -> None: - # XXX done: converted to new nodes imports = set() # type: Set[str] for node in module.all_nodes(): if isinstance(node, Directive): @@ -391,13 +385,12 @@ class PlyParser: def process_imports(self, module: Module) -> None: # (recursively) imports the modules - # XXX done: converted to new nodes imported = [] - for directive in module.all_nodes([Directive]): - if directive.name == "import": - if len(directive.args) < 1: + for directive in module.all_nodes(Directive): + if directive.name == "import": # type: ignore + if len(directive.args) < 1: # type: ignore raise ParseError("missing argument(s) for import directive", directive.sourceref) - for arg in directive.args: + for arg in directive.args: # type: ignore filename = self.find_import_file(arg, directive.sourceref.file) if not filename: raise ParseError("imported file not found", directive.sourceref) diff --git a/il65/plyparse.py b/il65/plyparse.py index 33e910a9e..362945148 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -75,17 +75,14 @@ class AstNode: scope = scope.parent raise LookupError("no scope found in node ancestry") - def all_nodes(self, nodetypes: Sequence['AstNode']=None) -> Generator['AstNode', None, None]: - if nodetypes is None: - nodett = AstNode - else: - nodett = tuple(nodetypes) # type: ignore + def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]: + nodetypes = nodetypes or (AstNode, ) for node in self.nodes: - if isinstance(node, nodett): # type: ignore + if isinstance(node, nodetypes): # type: ignore yield node for node in self.nodes: if isinstance(node, AstNode): - yield from node.all_nodes(nodetypes) + yield from node.all_nodes(*nodetypes) def remove_node(self, node: 'AstNode') -> None: self.nodes.remove(node) @@ -597,15 +594,15 @@ class Return(AstNode): # one, two or three subnodes: value_A, value_X, value_Y (all three Expression) @property def value_A(self) -> Expression: - return self.nodes[0] # type: ignore + return self.nodes[0] if self.nodes else None # type: ignore @property def value_X(self) -> Expression: - return self.nodes[0] # type: ignore + return self.nodes[0] if self.nodes else None # type: ignore @property def value_Y(self) -> Expression: - return self.nodes[0] # type: ignore + return self.nodes[0] if self.nodes else None # type: ignore @attr.s(cmp=False, slots=True, repr=False) diff --git a/tests/test_core.py b/tests/test_core.py index d034e8ad8..340fa875b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -102,7 +102,7 @@ def test_char_to_bytevalue(): def test_coerce_value(): def lv(v) -> LiteralValue: - return LiteralValue(value=v, sourceref=SourceRef("test", 1, 1)) + return LiteralValue(value=v, sourceref=SourceRef("test", 1, 1)) # type: ignore assert coerce_constant_value(datatypes.DataType.BYTE, lv(0)) == (False, lv(0)) assert coerce_constant_value(datatypes.DataType.BYTE, lv(255)) == (False, lv(255)) assert coerce_constant_value(datatypes.DataType.BYTE, lv('@')) == (True, lv(64)) diff --git a/tests/test_parser.py b/tests/test_parser.py index 932d8a595..4b916bfe7 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -137,7 +137,7 @@ def test_parser(): assert sub2.lineref == "src l. 19" all_nodes = list(result.all_nodes()) assert len(all_nodes) == 12 - all_nodes = list(result.all_nodes([Subroutine])) + all_nodes = list(result.all_nodes(Subroutine)) assert len(all_nodes) == 1 assert isinstance(all_nodes[0], Subroutine) assert isinstance(all_nodes[0].parent, Scope)