diff --git a/il65/compile.py b/il65/compile.py index cc12d3e7e..b5a8174e8 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -13,7 +13,8 @@ from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check import attr from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \ SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\ - SymbolName, Dereference, AddressOf, IncrDecr, Label, AstNode, datatype_of, coerce_constant_value, UndefinedSymbolError + SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \ + check_symbol_definition, UndefinedSymbolError from .plylex import SourceRef, print_bold from .datatypes import DataType, VarType @@ -23,9 +24,9 @@ class CompileError(Exception): class PlyParser: - def __init__(self, parsing_import: bool=False) -> None: + def __init__(self, imported_module: bool=False) -> None: self.parse_errors = 0 - self.parsing_import = parsing_import + self.imported_module = imported_module def parse_file(self, filename: str) -> Module: print("parsing:", filename) @@ -34,15 +35,17 @@ class PlyParser: module = parse_file(filename, self.lexer_error) self.check_directives(module) self.process_imports(module) + self.check_all_symbolnames(module) self.create_multiassigns(module) self.check_and_merge_zeropages(module) self.process_all_expressions_and_symbolnames(module) - if not self.parsing_import: - # these shall only be done on the main module after all imports have been done: - self.apply_directive_options(module) - self.determine_subroutine_usage(module) - self.semantic_check(module) - self.allocate_zeropage_vars(module) + return module # XXX + # if not self.parsing_import: + # # these shall only be done on the main module after all imports have been done: + # self.apply_directive_options(module) + # self.determine_subroutine_usage(module) + # self.semantic_check(module) + # self.allocate_zeropage_vars(module) except ParseError as x: self.handle_parse_error(x) if self.parse_errors: @@ -69,35 +72,33 @@ class PlyParser: raise ParseError("last statement in a block/subroutine must be a return or goto, " "(or %noreturn directive to silence this error)", last_stmt.sourceref) - def semantic_check(self, module: Module) -> None: - # perform semantic analysis / checks on the syntactic parse tree we have so far - # (note: symbol names have already been checked to exist when we start this) - for block, parent in module.all_scopes(): - assert isinstance(block, (Module, Block, Subroutine)) - assert parent is None or isinstance(parent, (Module, Block, Subroutine)) - previous_stmt = None - for stmt in block.nodes: - if isinstance(stmt, SubCall): - if isinstance(stmt.target.target, SymbolName): - subdef = block.scope.lookup(stmt.target.target.name) - self.check_subroutine_arguments(stmt, subdef) - if isinstance(stmt, Subroutine): - # the previous statement (if any) must be a Goto or Return - if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)): - raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", stmt.sourceref) - if isinstance(previous_stmt, Subroutine): - # the statement after a subroutine can not be some random executable instruction because it could not be reached - if not isinstance(stmt, (Subroutine, Label, Directive, InlineAssembly, VarDef)): - raise ParseError("statement following a subroutine can't be runnable code, " - "at least use a label first", stmt.sourceref) - previous_stmt = stmt - if isinstance(stmt, IncrDecr): - if isinstance(stmt.target, SymbolName): - symdef = block.scope.lookup(stmt.target.name) - if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: - raise ParseError("cannot modify a constant", stmt.sourceref) - if parent and block.name != "ZP" and not isinstance(stmt, (Return, Goto)): - self._check_last_statement_is_return(stmt) + # def semantic_check(self, module: Module) -> None: + # # perform semantic analysis / checks on the syntactic parse tree we have so far + # # (note: symbol names have already been checked to exist when we start this) + # for node, parent in module.all_nodes(): + # previous_stmt = None + # if isinstance(node, SubCall): + # if isinstance(node.target, SymbolName): + # subdef = block.scope.lookup(stmt.target.target.name) + # self.check_subroutine_arguments(stmt, subdef) + # if isinstance(stmt, Subroutine): + # # the previous statement (if any) must be a Goto or Return + # if previous_stmt and not isinstance(previous_stmt, (Goto, Return, VarDef, Subroutine)): + # raise ParseError("statement preceding subroutine must be a goto or return or another subroutine", stmt.sourceref) + # if isinstance(previous_stmt, Subroutine): + # # the statement after a subroutine can not be some random executable instruction because it could not be reached + # if not isinstance(stmt, (Subroutine, Label, Directive, InlineAssembly, VarDef)): + # raise ParseError("statement following a subroutine can't be runnable code, " + # "at least use a label first", stmt.sourceref) + # previous_stmt = stmt + # if isinstance(stmt, IncrDecr): + # if isinstance(stmt.target, SymbolName): + # symdef = block.scope.lookup(stmt.target.name) + # if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: + # raise ParseError("cannot modify a constant", stmt.sourceref) + # + # if parent and block.name != "ZP" and not isinstance(stmt, (Return, Goto)): + # self._check_last_statement_is_return(stmt) def check_subroutine_arguments(self, call: SubCall, subdef: Subroutine) -> None: # @todo must be moved to expression processing, or, restructure whole AST tree walking to make it easier to walk over everything @@ -110,8 +111,9 @@ class PlyParser: def check_and_merge_zeropages(self, module: Module) -> None: # merge all ZP blocks into one + # XXX done: converted to new nodes zeropage = None - for block in list(module.scope.filter_nodes(Block)): + for block in module.all_nodes([Block]): if block.name == "ZP": if zeropage: # merge other ZP block into first ZP block @@ -124,7 +126,7 @@ class PlyParser: raise ParseError("only variables and directives allowed in zeropage block", node.sourceref) else: zeropage = block - module.scope.remove_node(block) + block.parent.remove_node(block) if zeropage: # add the zero page again, as the very first block module.scope.add_node(zeropage, 0) @@ -146,37 +148,42 @@ class PlyParser: except CompileError as x: raise ParseError(str(x), vardef.sourceref) - @no_type_check + def check_all_symbolnames(self, module: Module) -> None: + for node in module.all_nodes([SymbolName]): + check_symbol_definition(node.name, node.my_scope(), node.sourceref) + def process_all_expressions_and_symbolnames(self, module: Module) -> None: - # process/simplify all expressions (constant folding etc), and check all symbol names + # process/simplify all expressions (constant folding etc) encountered_blocks = set() - for block, parent in module.all_scopes(): - parentname = (parent.name + ".") if parent else "" - blockname = parentname + block.name - if blockname in encountered_blocks: - raise ValueError("block names not unique:", blockname) - encountered_blocks.add(blockname) - for node in block.nodes: - try: - node.verify_symbol_names(block.scope) - node.process_expressions(block.scope) - except ParseError: - raise - except Exception as x: - self.handle_internal_error(x, "process_expressions of node {} in block {}".format(node, block.name)) - if isinstance(node, IncrDecr) and node.howmuch not in (0, 1): - _, node.howmuch = coerce_constant_value(datatype_of(node.target, block.scope), node.howmuch, node.sourceref) - elif isinstance(node, Assignment): - lvalue_types = set(datatype_of(lv, block.scope) for lv in node.left) - if len(lvalue_types) == 1: - _, node.right = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) - else: - for lv_dt in lvalue_types: - coerce_constant_value(lv_dt, node.right, node.sourceref) + for node in module.all_nodes(): + if isinstance(node, Block): + parentname = (node.parent.name + ".") if node.parent else "" + blockname = parentname + node.name + if blockname in encountered_blocks: + raise ValueError("block names not unique:", blockname) + encountered_blocks.add(blockname) + elif isinstance(node, Expression): + print("EXPRESSION", node) # XXX + # try: + # node.process_expressions(block.scope) + # except ParseError: + # raise + # except Exception as x: + # self.handle_internal_error(x, "process_expressions of node {} in block {}".format(node, block.name)) + elif isinstance(node, IncrDecr) and node.howmuch not in (0, 1): + _, node.howmuch = coerce_constant_value(datatype_of(node.target, node.my_scope()), node.howmuch, node.sourceref) + elif isinstance(node, Assignment): + lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes) + if len(lvalue_types) == 1: + _, node.right = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) + else: + for lv_dt in lvalue_types: + coerce_constant_value(lv_dt, node.right, node.sourceref) def create_multiassigns(self, module: Module) -> None: # create multi-assign statements from nested assignments (A=B=C=5), # and optimize TargetRegisters down to single Register if it's just one register. + # XXX done: converted to new nodes def reduce_right(assign: Assignment) -> Assignment: if isinstance(assign.right, Assignment): right = reduce_right(assign.right) @@ -184,12 +191,10 @@ class PlyParser: assign.right = right.right return assign - for block, parent in module.all_scopes(): - for node in block.nodes: # type: ignore - if isinstance(node, Assignment): - if isinstance(node.right, Assignment): - multi = reduce_right(node) - assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment) + for node in module.all_nodes([Assignment]): + if isinstance(node.right, Assignment): + multi = reduce_right(node) + assert multi is node and len(multi.left) > 1 and not isinstance(multi.right, Assignment) def apply_directive_options(self, module: Module) -> None: def set_save_registers(scope: Scope, save_dir: Directive) -> None: @@ -284,7 +289,7 @@ class PlyParser: self._get_subroutine_usages_from_assignment(module.subroutine_usage, node, block.scope) print("----------SUBROUTINES IN USE-------------") # XXX import pprint - pprint.pprint(module.subroutine_usage) # XXX + pprint.pprint(module.subroutine_usage) # XXX print("----------/SUBROUTINES IN USE-------------") # XXX def _get_subroutine_usages_from_subcall(self, usages: Dict[Tuple[str, str], Set[str]], @@ -307,7 +312,7 @@ class PlyParser: elif isinstance(expr, LiteralValue): return elif isinstance(expr, Dereference): - return self._get_subroutine_usages_from_expression(usages, expr.location, parent_scope) + return self._get_subroutine_usages_from_expression(usages, expr.operand, parent_scope) elif isinstance(expr, AddressOf): return self._get_subroutine_usages_from_expression(usages, expr.name, parent_scope) elif isinstance(expr, SymbolName): @@ -365,34 +370,31 @@ class PlyParser: usages[(namespace, symbol.name)].add(str(asmnode.sourceref)) def check_directives(self, module: Module) -> None: - for node, parent in module.all_scopes(): - if isinstance(node, Module): - # check module-level directives - imports = set() # type: Set[str] - for directive in node.scope.filter_nodes(Directive): - if directive.name not in {"output", "zp", "address", "import", "saveregisters", "noreturn"}: - raise ParseError("invalid directive in module", directive.sourceref) - if directive.name == "import": - if imports & set(directive.args): - raise ParseError("duplicate import", directive.sourceref) - imports |= set(directive.args) - if isinstance(node, (Block, Subroutine)): - # check block and subroutine-level directives - first_node = True - if not node.scope: - continue - for sub_node in node.scope.nodes: - if isinstance(sub_node, Directive): - if sub_node.name not in {"asmbinary", "asminclude", "breakpoint", "saveregisters", "noreturn"}: - raise ParseError("invalid directive in " + node.__class__.__name__.lower(), sub_node.sourceref) - if sub_node.name == "saveregisters" and not first_node: - raise ParseError("saveregisters directive must be the first", sub_node.sourceref) - first_node = False + # XXX done: converted to new nodes + imports = set() # type: Set[str] + for node in module.all_nodes(): + if isinstance(node, Directive): + assert isinstance(node.parent, Scope) + if node.parent.level == "module": + if node.name not in {"output", "zp", "address", "import", "saveregisters", "noreturn"}: + raise ParseError("invalid directive in module", node.sourceref) + if node.name == "import": + if imports & set(node.args): + raise ParseError("duplicate import", node.sourceref) + imports |= set(node.args) + else: + if node.name not in {"asmbinary", "asminclude", "breakpoint", "saveregisters", "noreturn"}: + raise ParseError("invalid directive in " + node.parent.__class__.__name__.lower(), node.sourceref) + if node.name == "saveregisters": + # it should be the first node in the scope + if node.parent.nodes[0] is not node: + raise ParseError("saveregisters directive must be first in this scope", node.sourceref) def process_imports(self, module: Module) -> None: # (recursively) imports the modules + # XXX done: converted to new nodes imported = [] - for directive in module.scope.filter_nodes(Directive): + for directive in module.all_nodes([Directive]): if directive.name == "import": if len(directive.args) < 1: raise ParseError("missing argument(s) for import directive", directive.sourceref) @@ -404,7 +406,7 @@ class PlyParser: imported_module.scope.parent_scope = module.scope imported.append(imported_module) self.parse_errors += import_parse_errors - if not self.parsing_import: + if not self.imported_module: # compiler support library is always imported (in main parser) filename = self.find_import_file("il65lib", module.sourceref.file) if filename: @@ -414,13 +416,14 @@ class PlyParser: self.parse_errors += import_parse_errors else: raise FileNotFoundError("missing il65lib") - # append the imported module's contents (blocks) at the end of the current module - for imported_module in imported: - for block in imported_module.scope.filter_nodes(Block): - module.scope.add_node(block) + # XXX append the imported module's contents (blocks) at the end of the current module + # for block in (node for imported_module in imported + # for node in imported_module.scope.nodes + # if isinstance(node, Block)): + # module.scope.add_node(block) def import_file(self, filename: str) -> Tuple[Module, int]: - sub_parser = PlyParser(parsing_import=True) + sub_parser = PlyParser(imported_module=True) return sub_parser.parse_file(filename), sub_parser.parse_errors def find_import_file(self, modulename: str, sourcefile: str) -> Optional[str]: @@ -443,7 +446,7 @@ class PlyParser: out = sys.stdout if out.isatty(): print("\x1b[1m", file=out) - if self.parsing_import: + if self.imported_module: print("Error (in imported file):", str(exc), file=out) else: print("Error:", str(exc), file=out) diff --git a/il65/emit/generate.py b/il65/emit/generate.py index aa2609e7e..96822579e 100644 --- a/il65/emit/generate.py +++ b/il65/emit/generate.py @@ -10,7 +10,7 @@ import datetime from typing import TextIO, Callable from ..plylex import print_bold from ..plyparse import Module, Scope, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, AstNode, ZpOptions, \ - InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr + InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr, AssignmentTargets from . import CodeError, to_hex from .variables import generate_block_init, generate_block_vars from .assignment import generate_assignment, generate_aug_assignment @@ -190,15 +190,21 @@ class AssemblyGenerator: elif isinstance(stmt, Return): if stmt.value_A: reg = Register(name="A", sourceref=stmt.sourceref) # type: ignore - assignment = Assignment(left=[reg], right=stmt.value_A, sourceref=stmt.sourceref) # type: ignore + assignment = Assignment(sourceref=stmt.sourceref) # type: ignore + assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(stmt.value_A) generate_assignment(out, assignment) if stmt.value_X: reg = Register(name="X", sourceref=stmt.sourceref) # type: ignore - assignment = Assignment(left=[reg], right=stmt.value_X, sourceref=stmt.sourceref) # type: ignore + assignment = Assignment(sourceref=stmt.sourceref) # type: ignore + assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(stmt.value_X) generate_assignment(out, assignment) if stmt.value_Y: reg = Register(name="Y", sourceref=stmt.sourceref) # type: ignore - assignment = Assignment(left=[reg], right=stmt.value_Y, sourceref=stmt.sourceref) # type: ignore + assignment = Assignment(sourceref=stmt.sourceref) # type: ignore + assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(stmt.value_Y) generate_assignment(out, assignment) out("\vrts") elif isinstance(stmt, InlineAssembly): diff --git a/il65/main.py b/il65/main.py index 8bea70a30..b4eb12565 100644 --- a/il65/main.py +++ b/il65/main.py @@ -81,6 +81,7 @@ def main() -> None: print("\nParsing program source code.") parser = PlyParser() parsed_module = parser.parse_file(args.sourcefile) + raise SystemExit("First fix the parser to iterate all nodes in the new way.") # XXX if parsed_module: if args.nooptimize: print_bold("not optimizing the parse tree!") diff --git a/il65/optimize.py b/il65/optimize.py index 7fc17637a..7e4364e15 100644 --- a/il65/optimize.py +++ b/il65/optimize.py @@ -6,7 +6,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ from .plyparse import Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr,\ - datatype_of, coerce_constant_value + datatype_of, coerce_constant_value, AssignmentTargets from .plylex import print_warning, print_bold @@ -47,14 +47,17 @@ class Optimizer: block.scope.remove_node(assignment) if assignment.right >= 8 and assignment.operator in ("<<=", ">>="): print("{}: shifting result is always zero".format(assignment.sourceref)) - new_stmt = Assignment(left=[assignment.left], right=0, sourceref=assignment.sourceref) + new_stmt = Assignment(sourceref=assignment.sourceref) + new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref)) + new_stmt.nodes.append(0) block.scope.replace_node(assignment, new_stmt) if assignment.operator in ("+=", "-=") and 0 < assignment.right < 256: howmuch = assignment.right if howmuch not in (0, 1): _, howmuch = coerce_constant_value(datatype_of(assignment.left, block.scope), howmuch, assignment.sourceref) - new_stmt = IncrDecr(target=assignment.left, operator="++" if assignment.operator == "+=" else "--", + new_stmt = IncrDecr(operator="++" if assignment.operator == "+=" else "--", howmuch=howmuch, sourceref=assignment.sourceref) + new_stmt.target = assignment.left block.scope.replace_node(assignment, new_stmt) def combine_assignments_into_multi(self): diff --git a/il65/plyparse.py b/il65/plyparse.py index b3bd81bb0..e8443ade1 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -10,7 +10,7 @@ import builtins import inspect import enum from collections import defaultdict -from typing import Union, Generator, Tuple, List, Optional, Dict, Any, Iterable +from typing import Union, Generator, Tuple, Sequence, List, Optional, Dict, Any, no_type_check import attr from ply.yacc import yacc from .plylex import SourceRef, tokens, lexer, find_tok_column, print_warning @@ -55,9 +55,12 @@ class UndefinedSymbolError(LookupError): start = "start" -@attr.s(cmp=False, slots=True, frozen=False) +@attr.s(cmp=False, slots=True, frozen=False, repr=False) class AstNode: + # all ast nodes have: sourceref, parent, and nodes (=list of zero or more sub-nodes) sourceref = attr.ib(type=SourceRef) + parent = attr.ib(init=False, default=None) # will be hooked up later + nodes = attr.ib(type=list, init=False, default=attr.Factory(list)) # type: List['AstNode'] # when evaluating an expression, does it have to be a constant value?: processed_expr_must_be_constant = attr.ib(type=bool, init=False, default=False) @@ -65,49 +68,60 @@ class AstNode: def lineref(self) -> str: return "src l. " + str(self.sourceref.line) - def print_tree(self) -> None: - def tostr(node: AstNode, level: int) -> None: - if not isinstance(node, AstNode): - return - indent = " " * level - name = getattr(node, "name", "") - print(indent, node.__class__.__name__, repr(name)) - try: - variables = vars(node).items() - except TypeError: - return - for name, value in variables: - if isinstance(value, AstNode): - tostr(value, level + 1) - if isinstance(value, (list, tuple, set)): - if len(value) > 0: - elt = list(value)[0] - if isinstance(elt, AstNode) or name == "nodes": - print(indent, " >", name, "=") - for elt in value: - tostr(elt, level + 2) - tostr(self, 0) + def my_scope(self) -> 'Scope': + # returns the closest Scope in the ancestry of this node, or raises LookupError if no scope is found + scope = self + while scope: + if isinstance(scope, Scope): + return scope + scope = scope.parent + raise LookupError("no scope found in node ancestry") - def process_expressions(self, scope: 'Scope') -> None: + def all_nodes(self, nodetypes: Sequence['AstNode']=None) -> Generator['AstNode', None, None]: + if nodetypes is None: + nodett = AstNode + else: + nodett = tuple(nodetypes) # type: ignore + for node in self.nodes: + if isinstance(node, nodett): # type: ignore + yield node + for node in self.nodes: + if isinstance(node, AstNode): + yield from node.all_nodes(nodetypes) + + def remove_node(self, node: 'AstNode') -> None: + self.nodes.remove(node) + + def replace_node(self, oldnode: 'AstNode', newnode: 'AstNode') -> None: + assert isinstance(newnode, AstNode) + idx = self.nodes.index(oldnode) + self.nodes[idx] = newnode + + def add_node(self, newnode: 'AstNode', index: int = None) -> None: + assert isinstance(newnode, AstNode) + if index is None: + self.nodes.append(newnode) + else: + self.nodes.insert(index, newnode) + + def process_expressions(self, scope: 'Scope') -> None: # XXX remove, use all_nodes # process/simplify all expressions (constant folding etc) # this is implemented in node types that have expression(s) and that should act on this. pass - def verify_symbol_names(self, scope: 'Scope') -> None: - # check all SymbolNames to see if they exist. - # this is implemented in node types that have expression(s) and that should act on this. - pass - -@attr.s(cmp=False, repr=False) +@attr.s(cmp=False) class Directive(AstNode): name = attr.ib(type=str) args = attr.ib(type=list, default=attr.Factory(list)) + # no subnodes. @attr.s(cmp=False, slots=True, repr=False) class Scope(AstNode): - nodes = attr.ib(type=list) + # has zero or more subnodes + level = attr.ib(type=str, init=True) + nodes = attr.ib(type=list, init=True) # requires nodes in __init__ symbols = attr.ib(init=False) name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc. parent_scope = attr.ib(init=False, default=None) # will be wired up later @@ -162,32 +176,21 @@ class Scope(AstNode): return self.parent_scope.lookup(name) raise UndefinedSymbolError("undefined symbol: " + name) - def filter_nodes(self, nodetype) -> Generator[AstNode, None, None]: - for node in self.nodes: - if isinstance(node, nodetype): - yield node - def remove_node(self, node: AstNode) -> None: if hasattr(node, "name"): try: del self.symbols[node.name] # type: ignore except KeyError: pass - self.nodes.remove(node) + super().remove_node(node) def replace_node(self, oldnode: AstNode, newnode: AstNode) -> None: - assert isinstance(newnode, AstNode) - idx = self.nodes.index(oldnode) - self.nodes[idx] = newnode if hasattr(oldnode, "name"): del self.symbols[oldnode.name] # type: ignore + super().replace_node(oldnode, newnode) def add_node(self, newnode: AstNode, index: int=None) -> None: - assert isinstance(newnode, AstNode) - if index is None: - self.nodes.append(newnode) - else: - self.nodes.insert(index, newnode) + super().add_node(newnode, index) self._populate_symboltable(newnode) @@ -224,19 +227,20 @@ def dimensions_validator(obj: 'DatatypeNode', attrib: attr.Attribute, value: Lis @attr.s(cmp=False, repr=False) class Block(AstNode): - scope = attr.ib(type=Scope) + # has one subnode: the Scope. name = attr.ib(type=str, default=None) address = attr.ib(type=int, default=None, validator=validate_address) _unnamed_block_labels = {} # type: Dict[Block, str] - def __attrs_post_init__(self): - self.scope.name = self.name - @property - def nodes(self) -> Iterable[AstNode]: - if self.scope: - return self.scope.nodes - return [] + def scope(self) -> Scope: + return self.nodes[0] if self.nodes else None # type: ignore + + @scope.setter + def scope(self, scope: Scope) -> None: + self.nodes.clear() + self.nodes.append(scope) + scope.name = self.name @property def label(self) -> str: @@ -251,28 +255,18 @@ class Block(AstNode): @attr.s(cmp=False, repr=False) class Module(AstNode): + # has one subnode: the Scope. name = attr.ib(type=str) # filename - scope = attr.ib(type=Scope) subroutine_usage = attr.ib(type=defaultdict, init=False, default=attr.Factory(lambda: defaultdict(set))) # will be populated later format = attr.ib(type=ProgramFormat, init=False, default=ProgramFormat.PRG) # can be set via directive address = attr.ib(type=int, init=False, default=0xc000, validator=validate_address) # can be set via directive zp_options = attr.ib(type=ZpOptions, init=False, default=ZpOptions.NOCLOBBER) # can be set via directive @property - def nodes(self) -> Iterable[AstNode]: - if self.scope: - return self.scope.nodes - return [] - - def all_scopes(self) -> Generator[Tuple[AstNode, AstNode], None, None]: - # generator that recursively yields through the scopes (preorder traversal), yields (node, parent_node) tuples. - # it iterates of copies of the node collections, so it's okay to modify the scopes you iterate over. - yield self, None - for block in list(self.scope.filter_nodes(Block)): - yield block, self - for subroutine in list(block.scope.filter_nodes(Subroutine)): - yield subroutine, block + def scope(self) -> Scope: + return self.nodes[0] if self.nodes else None # type: ignore + @no_type_check def zeropage(self) -> Optional[Block]: # return the zeropage block (if defined) first_block = next(self.scope.filter_nodes(Block)) @@ -280,6 +274,7 @@ class Module(AstNode): return first_block return None + @no_type_check def main(self) -> Optional[Block]: # return the 'main' block (if defined) for block in self.scope.filter_nodes(Block): @@ -288,15 +283,17 @@ class Module(AstNode): return None -@attr.s(cmp=False, repr=False) +@attr.s(cmp=False) class Label(AstNode): name = attr.ib(type=str) + # no subnodes. -@attr.s(cmp=False, repr=False) +@attr.s(cmp=False, slots=True) class Register(AstNode): name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS)) datatype = attr.ib(type=DataType, init=False) + # no subnodes. def __attrs_post_init__(self): if self.name in REGISTER_BYTES: @@ -320,121 +317,255 @@ class Register(AstNode): return self.name < other.name -@attr.s(cmp=False, repr=False) +@attr.s(cmp=False) class PreserveRegs(AstNode): registers = attr.ib(type=str) + # no subnodes. -@attr.s(cmp=False, repr=False) -class Assignment(AstNode): - # can be single- or multi-assignment - left = attr.ib(type=list) # type: List[Union[str, TargetRegisters, Dereference]] - right = attr.ib() - - def process_expressions(self, scope: Scope) -> None: - self.right = process_expression(self.right, scope, self.right.sourceref) - - def verify_symbol_names(self, scope: Scope) -> None: - for lv in self.left: - if isinstance(lv, SymbolName): - check_symbol_definition(lv.name, scope, lv.sourceref) - elif isinstance(lv, Dereference): - if isinstance(lv.location, SymbolName): - check_symbol_definition(lv.location.name, scope, lv.location.sourceref) - # the symbols in the assignment rvalue are checked when its expression is processed. - - -@attr.s(cmp=False, repr=False) -class AugAssignment(AstNode): - left = attr.ib() - operator = attr.ib(type=str) - right = attr.ib() - - def process_expressions(self, scope: Scope) -> None: - self.right = process_expression(self.right, scope, self.right.sourceref) - - def verify_symbol_names(self, scope: Scope) -> None: - if isinstance(self.left, SymbolName): - check_symbol_definition(self.left.name, scope, self.left.sourceref) - # the symbols in the assignment rvalue are checked when its expression is processed. - - -@attr.s(cmp=False, repr=False) -class SubCall(AstNode): - target = attr.ib() - preserve_regs = attr.ib() - arguments = attr.ib() - - def __attrs_post_init__(self): - self.arguments = self.arguments or [] - - def process_expressions(self, scope: Scope) -> None: - for callarg in self.arguments: - assert isinstance(callarg, CallArgument) - callarg.process_expressions(scope) - - def verify_symbol_names(self, scope: Scope) -> None: - if isinstance(self.target.target, SymbolName): - check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref) - # the symbols in the subroutine's arguments are checked when their expression is processed. - - -@attr.s(cmp=False, repr=False) -class Return(AstNode): - value_A = attr.ib(default=None) - value_X = attr.ib(default=None) - value_Y = attr.ib(default=None) - - def process_expressions(self, scope: Scope) -> None: - if self.value_A is not None: - self.value_A = process_expression(self.value_A, scope, self.sourceref) - if isinstance(self.value_A, (int, float, str, bool)): - try: - _, self.value_A = coerce_constant_value(DataType.BYTE, self.value_A, self.sourceref) - except (OverflowError, TypeError) as x: - raise ParseError("first value (A): " + str(x), self.sourceref) from None - if self.value_X is not None: - self.value_X = process_expression(self.value_X, scope, self.sourceref) - if isinstance(self.value_X, (int, float, str, bool)): - try: - _, self.value_X = coerce_constant_value(DataType.BYTE, self.value_X, self.sourceref) - except (OverflowError, TypeError) as x: - raise ParseError("second value (X): " + str(x), self.sourceref) from None - if self.value_Y is not None: - self.value_Y = process_expression(self.value_Y, scope, self.sourceref) - if isinstance(self.value_Y, (int, float, str, bool)): - try: - _, self.value_Y = coerce_constant_value(DataType.BYTE, self.value_Y, self.sourceref) - except (OverflowError, TypeError) as x: - raise ParseError("third value (Y): " + str(x), self.sourceref) from None - - -@attr.s(cmp=False, repr=False) +@attr.s(cmp=False) class TargetRegisters(AstNode): - # This is a tuple of 1 or more registers. + # subnodes is is a list of 1 or more registers. # In it's multiple-register form it is only used to be able to parse # the result of a subroutine call such as A,X = sub(). # It will be replaced by a regular Register node if it contains just one register. - registers = attr.ib(type=list) - - def add(self, register: str) -> None: - self.registers.append(register) + pass @attr.s(cmp=False, repr=False) class InlineAssembly(AstNode): + # no subnodes. assembly = attr.ib(type=str) -@attr.s(cmp=False, repr=True, slots=True) +@attr.s(cmp=False, slots=True) +class DatatypeNode(AstNode): + # no subnodes. + name = attr.ib(type=str) + dimensions = attr.ib(type=list, default=None, validator=dimensions_validator) # if set, 1 or more dimensions (ints) + + def to_enum(self): + return { + "byte": DataType.BYTE, + "word": DataType.WORD, + "float": DataType.FLOAT, + "text": DataType.STRING, + "ptext": DataType.STRING_P, + "stext": DataType.STRING_S, + "pstext": DataType.STRING_PS, + "matrix": DataType.MATRIX, + "array": DataType.BYTEARRAY, + "wordarray": DataType.WORDARRAY + }[self.name] + + +@attr.s(cmp=False, repr=False) +class Subroutine(AstNode): + # one subnode: the Scope. + name = attr.ib(type=str) + param_spec = attr.ib(type=list) + result_spec = attr.ib(type=list) + address = attr.ib(type=int, default=None, validator=validate_address) + + @property + def scope(self) -> Scope: + return self.nodes[0] if self.nodes else None # type: ignore + + @scope.setter + def scope(self, scope: Scope) -> None: + self.nodes.clear() + self.nodes.append(scope) + scope.name = self.name + if self.address is not None: + raise ValueError("subroutine must have either a scope or an address, not both") + + +@attr.s(cmp=False, repr=False) +class Goto(AstNode): + # one or two subnodes: target (SymbolName, int or Dereference) and optionally: condition (Expression) + if_stmt = attr.ib(default=None) + + def process_expressions(self, scope: Scope) -> None: + if len(self.nodes) == 2: + self.nodes[1] = process_expression(self.nodes[1], scope, self.nodes[1].sourceref) + + +@attr.s(cmp=True, slots=True) +class LiteralValue(AstNode): + # no subnodes. + value = attr.ib() + + +@attr.s(cmp=False) +class AddressOf(AstNode): + # no subnodes. + name = attr.ib(type=str) + + +@attr.s(cmp=False, slots=True) +class SymbolName(AstNode): + # no subnodes. + name = attr.ib(type=str) + + +@attr.s(cmp=False) +class Dereference(AstNode): + # one subnode: operand (SymbolName, int or register name) + datatype = attr.ib() + size = attr.ib(type=int, default=None) + + @property + def operand(self) -> Union[SymbolName, int, str]: + return self.nodes[0] # type: ignore + + def __attrs_post_init__(self): + # convert datatype node to enum + size + if self.datatype is None: + assert self.size is None + self.size = 1 + self.datatype = DataType.BYTE + elif isinstance(self.datatype, DatatypeNode): + assert self.size is None + self.size = self.datatype.dimensions + if not self.datatype.to_enum().isnumeric(): + raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref) + self.datatype = self.datatype.to_enum() + + +@attr.s(cmp=False) +class IncrDecr(AstNode): + # increment or decrement something by a CONSTANT value (1 or more) + # one subnode: target (TargetRegisters, Register, SymbolName, or Dereference). + operator = attr.ib(type=str, validator=attr.validators.in_(["++", "--"])) + howmuch = attr.ib(default=1) + + @property + def target(self) -> Union[TargetRegisters, Register, SymbolName, Dereference]: + return self.nodes[0] # type: ignore + + @target.setter + def target(self, target: Union[TargetRegisters, Register, SymbolName, Dereference]) -> None: + if isinstance(target, Register): + if target.name not in REGISTER_BYTES | REGISTER_WORDS: + raise ParseError("cannot incr/decr that register", self.sourceref) + if isinstance(target, TargetRegisters): + raise ParseError("cannot incr/decr multiple registers at once", self.sourceref) + self.nodes.clear() + self.nodes.append(target) + + def __attrs_post_init__(self): + # make sure the amount is always >= 0 + if self.howmuch < 0: + self.howmuch = -self.howmuch + self.operator = "++" if self.operator == "--" else "--" + + +@attr.s(cmp=False, slots=True, repr=False) +class Expression(AstNode): + left = attr.ib() + operator = attr.ib(type=str) + right = attr.ib() + unary = attr.ib(type=bool, default=False) + + def __attrs_post_init__(self): + assert self.operator not in ("++", "--"), "incr/decr should not be an expression" + + def process_expressions(self, scope: Scope) -> None: + raise RuntimeError("must be done via parent node's process_expressions") + + def evaluate_primitive_constants(self, scope: Scope) -> LiteralValue: + # make sure the lvalue and rvalue are primitives, and the operator is allowed + assert isinstance(self.left, LiteralValue) + assert isinstance(self.right, LiteralValue) + if self.operator not in {'+', '-', '*', '/', '//', '~', '<', '>', '<=', '>=', '==', '!='}: + raise ValueError("operator", self) + estr = "{} {} {}".format(repr(self.left.value), self.operator, repr(self.right.value)) + try: + return eval(estr, {}, {}) # safe because of checks above + except Exception as x: + raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None + + def print_tree(self) -> None: + def tree(expr: Any, level: int) -> str: + indent = " "*level + if not isinstance(expr, Expression): + return indent + str(expr) + "\n" + if expr.unary: + return indent + "{}{}".format(expr.operator, tree(expr.left, level+1)) + else: + return indent + "{}".format(tree(expr.left, level+1)) + \ + indent + str(expr.operator) + "\n" + \ + indent + "{}".format(tree(expr.right, level + 1)) + print(tree(self, 0)) + + +@attr.s(cmp=False, slots=True) +class CallArgument(AstNode): + # one subnode: the value (Expression) + name = attr.ib(type=str, default=None) + + @property + def value(self) -> Expression: + return self.nodes[0] # type: ignore + + def process_expressions(self, scope: Scope) -> None: + self.nodes[0] = process_expression(self.nodes[0], scope, self.sourceref) + + +@attr.s(cmp=False) +class CallArguments(AstNode): + # subnodes are zero or more subroutine call arguments (CallArgument) + nodes = attr.ib(type=list, init=True) # requires nodes in __init__ + + +@attr.s(cmp=False, repr=False) +class SubCall(AstNode): + # has three subnodes: + # 0: target (Symbolname, int, or Dereference), + # 1: preserve_regs (PreserveRegs) + # 2: arguments (CallArguments). + + @property + def target(self) -> Union[SymbolName, int, Dereference]: + return self.nodes[0] # type: ignore + + @property + def preserve_regs(self) -> PreserveRegs: + return self.nodes[1] # type: ignore + + @property + def arguments(self) -> CallArguments: + return self.nodes[2] # type: ignore + + def process_expressions(self, scope: Scope) -> None: + for callarg in self.nodes[2].nodes: + assert isinstance(callarg, CallArgument) + callarg.process_expressions(scope) + + +@attr.s(cmp=False, slots=True, repr=False) class VarDef(AstNode): + # zero or one subnode: value (an Expression). name = attr.ib(type=str) vartype = attr.ib() datatype = attr.ib() - value = attr.ib(default=None) size = attr.ib(type=list, default=None) zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later + @property + def value(self) -> Expression: + return self.nodes[0] if self.nodes else None # type: ignore + + @value.setter + def value(self, newvalue: Expression) -> None: + if self.nodes: + self.nodes[0] = newvalue + else: + self.nodes.append(newvalue) + # if the value is an expression, mark it as a *constant* expression here + if isinstance(self.nodes[0], AstNode): # XXX expression only? + self.value.processed_expr_must_be_constant = True + def __attrs_post_init__(self): # convert vartype to enum if self.vartype == "const": @@ -456,13 +587,7 @@ class VarDef(AstNode): self.datatype = self.datatype.to_enum() if self.datatype.isarray() and sum(self.size) in (0, 1): print("warning: {}: array/matrix with size 1, use normal byte/word instead for efficiency".format(self.sourceref)) - if self.vartype == VarType.CONST and self.value is None: - raise ParseError("constant value assignment is missing", - attr.evolve(self.sourceref, column=self.sourceref.column+len(self.name))) - # if the value is an expression, mark it as a *constant* expression here - if isinstance(self.value, AstNode): - self.value.processed_expr_must_be_constant = True - elif self.value is None and (self.datatype.isnumeric() or self.datatype.isarray()): + if self.value is None and (self.datatype.isnumeric() or self.datatype.isarray()): self.value = 0 # if it's a matrix with interleave, it must be memory mapped if self.datatype == DataType.MATRIX and len(self.size) == 3: @@ -482,179 +607,87 @@ class VarDef(AstNode): raise ParseError("processed expression vor vardef is not a constant value: " + str(x), self.sourceref) from None -@attr.s(cmp=False, slots=True, repr=False) -class DatatypeNode(AstNode): - name = attr.ib(type=str) - dimensions = attr.ib(type=list, default=None, validator=dimensions_validator) # if set, 1 or more dimensions (ints) - - def to_enum(self): - return { - "byte": DataType.BYTE, - "word": DataType.WORD, - "float": DataType.FLOAT, - "text": DataType.STRING, - "ptext": DataType.STRING_P, - "stext": DataType.STRING_S, - "pstext": DataType.STRING_PS, - "matrix": DataType.MATRIX, - "array": DataType.BYTEARRAY, - "wordarray": DataType.WORDARRAY - }[self.name] - - @attr.s(cmp=False, repr=False) -class Subroutine(AstNode): - name = attr.ib(type=str) - param_spec = attr.ib(type=list) - result_spec = attr.ib(type=list) - scope = attr.ib(type=Scope, default=None) - address = attr.ib(type=int, default=None, validator=validate_address) +class Return(AstNode): + # one, two or three subnodes: value_A, value_X, value_Y (all three Expression) + @property + def value_A(self) -> Expression: + return self.nodes[0] # type: ignore @property - def nodes(self) -> Iterable[AstNode]: - if self.scope: - return self.scope.nodes - return [] + def value_X(self) -> Expression: + return self.nodes[0] # type: ignore - def __attrs_post_init__(self): - if self.scope and self.address is not None: - raise ValueError("subroutine must have either a scope or an address, not both") - if self.scope: - self.scope.name = self.name - - -@attr.s(cmp=False, repr=False) -class Goto(AstNode): - target = attr.ib() - if_stmt = attr.ib(default=None) - condition = attr.ib(default=None) + @property + def value_Y(self) -> Expression: + return self.nodes[0] # type: ignore def process_expressions(self, scope: Scope) -> None: - if self.condition is not None: - self.condition = process_expression(self.condition, scope, self.condition.sourceref) - - def verify_symbol_names(self, scope: Scope) -> None: - if isinstance(self.target.target, SymbolName): - check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref) - - -@attr.s(cmp=False, repr=False) -class Dereference(AstNode): - location = attr.ib() - datatype = attr.ib() - size = attr.ib(type=int, default=None) - - def __attrs_post_init__(self): - # convert datatype node to enum + size - if self.datatype is None: - assert self.size is None - self.size = 1 - self.datatype = DataType.BYTE - elif isinstance(self.datatype, DatatypeNode): - assert self.size is None - self.size = self.datatype.dimensions - if not self.datatype.to_enum().isnumeric(): - raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref) - self.datatype = self.datatype.to_enum() - - def verify_symbol_names(self, scope: Scope) -> None: - print("DEREF", self.location) # XXX not called????? - if isinstance(self.location, SymbolName): - check_symbol_definition(self.location.name, scope, self.location.sourceref) - - -@attr.s(cmp=False, repr=False) -class LiteralValue(AstNode): - value = attr.ib() - - -@attr.s(cmp=False, repr=False) -class AddressOf(AstNode): - name = attr.ib(type=str) - - -@attr.s(cmp=False, repr=False) -class IncrDecr(AstNode): - # increment or decrement something by a CONSTANT value (1 or more) - target = attr.ib() - operator = attr.ib(type=str, validator=attr.validators.in_(["++", "--"])) - howmuch = attr.ib(default=1) - - def __attrs_post_init__(self): - # make sure the amount is always >= 0 - if self.howmuch < 0: - self.howmuch = -self.howmuch - self.operator = "++" if self.operator == "--" else "--" - if isinstance(self.target, Register): - if self.target.name not in REGISTER_BYTES | REGISTER_WORDS: - raise ParseError("cannot incr/decr that register", self.sourceref) - if isinstance(self.target, TargetRegisters): - raise ParseError("cannot incr/decr multiple registers at once", self.sourceref) - - def verify_symbol_names(self, scope: Scope) -> None: - if isinstance(self.target, SymbolName): - check_symbol_definition(self.target.name, scope, self.target.sourceref) - - -@attr.s(cmp=False, repr=False) -class SymbolName(AstNode): - name = attr.ib(type=str) + if self.nodes[0] is not None: + self.nodes[0] = process_expression(self.nodes[0], scope, self.sourceref) + if isinstance(self.nodes[0], (int, float, str, bool)): + try: + _, self.nodes[0] = coerce_constant_value(DataType.BYTE, self.nodes[0], self.sourceref) + except (OverflowError, TypeError) as x: + raise ParseError("first value (A): " + str(x), self.sourceref) from None + if self.nodes[1] is not None: + self.nodes[1] = process_expression(self.nodes[1], scope, self.sourceref) + if isinstance(self.nodes[1], (int, float, str, bool)): + try: + _, self.nodes[1] = coerce_constant_value(DataType.BYTE, self.nodes[1], self.sourceref) + except (OverflowError, TypeError) as x: + raise ParseError("second value (X): " + str(x), self.sourceref) from None + if self.nodes[2] is not None: + self.nodes[2] = process_expression(self.nodes[2], scope, self.sourceref) + if isinstance(self.nodes[2], (int, float, str, bool)): + try: + _, self.nodes[2] = coerce_constant_value(DataType.BYTE, self.nodes[2], self.sourceref) + except (OverflowError, TypeError) as x: + raise ParseError("third value (Y): " + str(x), self.sourceref) from None @attr.s(cmp=False, slots=True, repr=False) -class CallTarget(AstNode): - target = attr.ib() - address_of = attr.ib(type=bool) +class AssignmentTargets(AstNode): + # a list of one or more assignment targets (TargetRegisters, Register, SymbolName, or Dereference). + nodes = attr.ib(type=list, init=True) # requires nodes in __init__ @attr.s(cmp=False, slots=True, repr=False) -class CallArgument(AstNode): - value = attr.ib() - name = attr.ib(type=str, default=None) +class Assignment(AstNode): + # can be single- or multi-assignment + # has two subnodes: left (=AssignmentTargets) and right (=Expression or another Assignment but those will be converted to multi assign) + + @property + def left(self) -> AssignmentTargets: + return self.nodes[0] # type: ignore + + @property + def right(self) -> Union[LiteralValue, Expression]: + return self.nodes[1] # type: ignore + + @right.setter + def right(self, rvalue: Union[LiteralValue, Expression]) -> None: + self.nodes[1] = rvalue def process_expressions(self, scope: Scope) -> None: - self.value = process_expression(self.value, scope, self.sourceref) + self.nodes[1] = process_expression(self.nodes[1], scope, self.nodes[1].sourceref) @attr.s(cmp=False, slots=True, repr=False) -class Expression(AstNode): - left = attr.ib() +class AugAssignment(AstNode): + # has two subnodes: left (=TargetRegisters, Register, SymbolName, or Dereference) and right (=Expression) operator = attr.ib(type=str) - right = attr.ib() - unary = attr.ib(type=bool, default=False) - def __attrs_post_init__(self): - assert self.operator not in ("++", "--"), "incr/decr should not be an expression" + @property + def left(self) -> Union[TargetRegisters, Register, SymbolName, Dereference]: + return self.nodes[0] # type: ignore + + @property + def right(self) -> Expression: + return self.nodes[1] # type: ignore def process_expressions(self, scope: Scope) -> None: - raise RuntimeError("must be done via parent node's process_expressions") - - def evaluate_primitive_constants(self, scope: Scope) -> Union[int, float, str, bool]: - # make sure the lvalue and rvalue are primitives, and the operator is allowed - if not isinstance(self.left, (LiteralValue, int, float, str, bool)): - raise TypeError("left", self) - if not isinstance(self.right, (LiteralValue, int, float, str, bool)): - raise TypeError("right", self) - if self.operator not in {'+', '-', '*', '/', '//', '~', '<', '>', '<=', '>=', '==', '!='}: - raise ValueError("operator", self) - estr = "{} {} {}".format(repr(self.left), self.operator, repr(self.right)) - try: - return eval(estr, {}, {}) # safe because of checks above - except Exception as x: - raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None - - def print_tree(self) -> None: - def tree(expr: Any, level: int) -> str: - indent = " "*level - if not isinstance(expr, Expression): - return indent + str(expr) + "\n" - if expr.unary: - return indent + "{}{}".format(expr.operator, tree(expr.left, level+1)) - else: - return indent + "{}".format(tree(expr.left, level+1)) + \ - indent + str(expr.operator) + "\n" + \ - indent + "{}".format(tree(expr.right, level + 1)) - print(tree(self, 0)) + self.nodes[1] = process_expression(self.nodes[1], scope, self.right.sourceref) def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType: @@ -666,14 +699,16 @@ def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType: if isinstance(symdef, VarDef): return symdef.datatype elif isinstance(assignmenttarget, TargetRegisters): - if len(assignmenttarget.registers) == 1: - return datatype_of(assignmenttarget.registers[0], scope) + if len(assignmenttarget.nodes) == 1: + return datatype_of(assignmenttarget.nodes[0], scope) raise TypeError("cannot determine datatype", assignmenttarget) -def coerce_constant_value(datatype: DataType, value: Any, - sourceref: SourceRef=None) -> Tuple[bool, Any]: +def coerce_constant_value(datatype: DataType, value: AstNode, + sourceref: SourceRef=None) -> Tuple[bool, AstNode]: # if we're a BYTE type, and the value is a single character, convert it to the numeric value + assert isinstance(value, AstNode) + def verify_bounds(value: Union[int, float, str]) -> None: # if the value is out of bounds, raise an overflow exception if isinstance(value, (int, float)): @@ -683,26 +718,34 @@ def coerce_constant_value(datatype: DataType, value: Any, raise OverflowError("value out of range for word: " + str(value)) if datatype == DataType.FLOAT and not (FLOAT_MAX_NEGATIVE <= value <= FLOAT_MAX_POSITIVE): # type: ignore raise OverflowError("value out of range for float: " + str(value)) - if isinstance(value, str) and len(value) == 1 and (datatype.isnumeric() or datatype.isarray()): - return True, char_to_bytevalue(value) - # if we're an integer value and the passed value is float, truncate it (and give a warning) - if datatype in (DataType.BYTE, DataType.WORD, DataType.MATRIX) and isinstance(value, float): - frac = math.modf(value) - if frac != 0: - print_warning("float value truncated ({} to datatype {})".format(value, datatype.name), sourceref=sourceref) - value = int(value) - verify_bounds(value) - return True, value - if isinstance(value, (int, float)): - verify_bounds(value) - if isinstance(value, (Expression, SubCall)): + + if isinstance(value, LiteralValue): + if type(value.value) is str and len(value.value) == 1 and (datatype.isnumeric() or datatype.isarray()): + # convert a string of length 1 to its numeric character value + return True, LiteralValue(value=char_to_bytevalue(value.value), sourceref=value.sourceref) # type: ignore + # if we're an integer value and the passed value is float, truncate it (and give a warning) + if datatype in (DataType.BYTE, DataType.WORD, DataType.MATRIX) and isinstance(value.value, float): + frac = math.modf(value.value) + if frac != 0: + print_warning("float value truncated ({} to datatype {})".format(value.value, datatype.name), sourceref=sourceref) + v2 = int(value.value) + verify_bounds(v2) + return True, LiteralValue(value=v2, sourceref=value.sourceref) # type: ignore + if type(value.value) in (int, float): + verify_bounds(value.value) + if datatype == DataType.WORD: + if type(value.value) not in (int, float, str): + raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref) + elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): + if type(value.value) not in (int, float): + raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref) + elif isinstance(value, (Expression, SubCall)): return False, value - elif datatype == DataType.WORD: - if not isinstance(value, (int, float, str, Dereference, Register, SymbolName, AddressOf)): - raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref) - elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): - if not isinstance(value, (int, float, Dereference, Register, SymbolName)): - raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref) + if datatype == DataType.WORD and not isinstance(value, (LiteralValue, Dereference, Register, SymbolName, AddressOf)): + raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref) + elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT) \ + and not isinstance(value, (LiteralValue, Dereference, Register, SymbolName)): + raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref) return False, value @@ -718,9 +761,11 @@ def process_expression(value: Any, scope: Scope, sourceref: SourceRef) -> Any: return process_dynamic_expression(value, sourceref, scope) -def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> Union[int, float, str, bool]: +def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> LiteralValue: # the expression must result in a single (constant) value (int, float, whatever) - if expr is None or isinstance(expr, (int, float, str, bool)): + if isinstance(expr, (int, float, str, bool)): + raise TypeError("expr node should not be a python primitive value", expr) + elif expr is None or isinstance(expr, LiteralValue): return expr elif isinstance(expr, LiteralValue): return expr.value @@ -733,7 +778,7 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc if isinstance(value, Expression): raise ExpressionEvaluationError("circular reference?", expr.sourceref) elif isinstance(value, (int, float, str, bool)): - return value + raise TypeError("symbol value node should not be a python primitive value", expr) else: raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref) elif isinstance(expr, AddressOf): @@ -741,7 +786,10 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc value = check_symbol_definition(expr.name.name, symbolscope, expr.sourceref) if isinstance(value, VarDef): if value.vartype == VarType.MEMORY: - return value.value + if isinstance(value.value, LiteralValue): + return value.value + else: + raise ExpressionEvaluationError("constant literal value required", value.sourceref) if value.vartype == VarType.CONST: raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref) raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant" @@ -750,43 +798,39 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc raise ExpressionEvaluationError("constant address required, not {}" .format(value.__class__.__name__), expr.name.sourceref) elif isinstance(expr, SubCall): - if isinstance(expr.target, CallTarget): - target = expr.target.target - if isinstance(target, SymbolName): # 'function(1,2,3)' - funcname = target.name - if funcname in math_functions or funcname in builtin_functions: - func_args = [] - for a in (process_constant_expression(callarg.value, sourceref, symbolscope) for callarg in expr.arguments): - if isinstance(a, LiteralValue): - func_args.append(a.value) - else: - func_args.append(a) - func = math_functions.get(funcname, builtin_functions.get(funcname)) - try: - return func(*func_args) - except Exception as x: - raise ExpressionEvaluationError(str(x), expr.sourceref) - else: - raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref) - elif isinstance(target, Dereference): # '[...](1,2,3)' - raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref) - elif type(target) is int: # '64738()' - raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) + if isinstance(expr.target, SymbolName): # 'function(1,2,3)' + funcname = expr.target.name + if funcname in math_functions or funcname in builtin_functions: + func_args = [] + for a in (process_constant_expression(callarg.value, sourceref, symbolscope) for callarg in expr.arguments.nodes): + if isinstance(a, LiteralValue): + func_args.append(a.value) + else: + func_args.append(a) + func = math_functions.get(funcname, builtin_functions.get(funcname)) + try: + return func(*func_args) + except Exception as x: + raise ExpressionEvaluationError(str(x), expr.sourceref) else: - raise NotImplementedError("weird call target", target) + raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref) + elif isinstance(expr.target, Dereference): # '[...](1,2,3)' + raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref) + elif type(expr.target) is int: # '64738()' + raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) else: - raise ParseError("function name required, not {}".format(expr.target.__class__.__name__), expr.sourceref) + raise NotImplementedError("weird call target", expr.target) elif not isinstance(expr, Expression): raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) if expr.unary: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = process_constant_expression(expr.left, left_sourceref, symbolscope) - if isinstance(expr.left, (int, float)): + if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): try: if expr.operator == '-': - return -expr.left + return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore elif expr.operator == '~': - return ~expr.left # type: ignore + return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore elif expr.operator in ("++", "--"): raise ValueError("incr/decr should not be an expression") raise ValueError("invalid unary operator", expr.operator) @@ -798,14 +842,14 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc expr.left = process_constant_expression(expr.left, left_sourceref, symbolscope) right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref expr.right = process_constant_expression(expr.right, right_sourceref, symbolscope) - if isinstance(expr.left, (LiteralValue, int, float, str, bool)): - if isinstance(expr.right, (LiteralValue, int, float, str, bool)): + if isinstance(expr.left, LiteralValue): + if isinstance(expr.right, LiteralValue): return expr.evaluate_primitive_constants(symbolscope) else: - raise ExpressionEvaluationError("constant value required on right, not {}" + raise ExpressionEvaluationError("constant literal value required on right, not {}" .format(expr.right.__class__.__name__), right_sourceref) else: - raise ExpressionEvaluationError("constant value required on left, not {}" + raise ExpressionEvaluationError("constant literal value required on left, not {}" .format(expr.left.__class__.__name__), left_sourceref) @@ -836,14 +880,14 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Sco try: return process_constant_expression(expr, sourceref, symbolscope) except ExpressionEvaluationError: - if isinstance(expr.target.target, SymbolName): - check_symbol_definition(expr.target.target.name, symbolscope, expr.target.target.sourceref) + if isinstance(expr.target, SymbolName): + check_symbol_definition(expr.target.name, symbolscope, expr.target.sourceref) return expr elif isinstance(expr, Register): return expr elif isinstance(expr, Dereference): - if isinstance(expr.location, SymbolName): - check_symbol_definition(expr.location.name, symbolscope, expr.location.sourceref) + if isinstance(expr.operand, SymbolName): + check_symbol_definition(expr.operand.name, symbolscope, expr.operand.sourceref) return expr elif not isinstance(expr, Expression): raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) @@ -873,13 +917,15 @@ def p_start(p): | module_elements """ if p[1]: - scope = Scope(nodes=p[1], sourceref=_token_sref(p, 1)) + scope = Scope(nodes=p[1], level="module", sourceref=_token_sref(p, 1)) scope.name = "<" + p.lexer.source_filename + " global scope>" - p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=_token_sref(p, 1)) + p[0] = Module(name=p.lexer.source_filename, sourceref=_token_sref(p, 1)) + p[0].nodes.append(scope) else: - scope = Scope(nodes=[], sourceref=_token_sref(p, 1)) + scope = Scope(nodes=[], level="module", sourceref=_token_sref(p, 1)) scope.name = "<" + p.lexer.source_filename + " global scope>" - p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=SourceRef(lexer.source_filename, 1, 1)) + p[0] = Module(name=p.lexer.source_filename, sourceref=SourceRef(lexer.source_filename, 1, 1)) + p[0].nodes.append(scope) def p_module(p): @@ -945,21 +991,24 @@ def p_block_name_addr(p): """ block : BITINVERT NAME INTEGER endl_opt scope """ - p[0] = Block(name=p[2], address=p[3], scope=p[5], sourceref=_token_sref(p, 2)) + p[0] = Block(name=p[2], address=p[3], sourceref=_token_sref(p, 2)) + p[0].scope = p[5] def p_block_name(p): """ block : BITINVERT NAME endl_opt scope """ - p[0] = Block(name=p[2], scope=p[4], sourceref=_token_sref(p, 2)) + p[0] = Block(name=p[2], sourceref=_token_sref(p, 2)) + p[0].scope = p[4] def p_block(p): """ block : BITINVERT endl_opt scope """ - p[0] = Block(scope=p[3], sourceref=_token_sref(p, 1)) + p[0] = Block(sourceref=_token_sref(p, 1)) + p[0].scope = p[3] def p_endl_opt(p): @@ -974,7 +1023,7 @@ def p_scope(p): """ scope : '{' scope_elements_opt '}' """ - p[0] = Scope(nodes=p[2] or [], sourceref=_token_sref(p, 1)) + p[0] = Scope(nodes=p[2] or [], level="block", sourceref=_token_sref(p, 1)) def p_scope_elements_opt(p): @@ -1040,7 +1089,8 @@ def p_vardef_value(p): """ vardef : VARTYPE type_opt NAME IS expression """ - p[0] = VarDef(name=p[3], vartype=p[1], datatype=p[2], value=p[5], sourceref=_token_sref(p, 3)) + p[0] = VarDef(name=p[3], vartype=p[1], datatype=p[2], sourceref=_token_sref(p, 3)) + p[0].value = p[5] def p_type_opt(p): @@ -1086,7 +1136,8 @@ def p_subroutine(p): """ body = p[10] if isinstance(body, Scope): - p[0] = Subroutine(name=p[2], param_spec=p[4] or [], result_spec=p[8] or [], scope=body, sourceref=_token_sref(p, 1)) + p[0] = Subroutine(name=p[2], param_spec=p[4] or [], result_spec=p[8] or [], sourceref=_token_sref(p, 1)) + p[0].scope = body elif type(body) is int: p[0] = Subroutine(name=p[2], param_spec=p[4] or [], result_spec=p[8] or [], address=body, sourceref=_token_sref(p, 1)) else: @@ -1183,14 +1234,18 @@ def p_incrdecr(p): incrdecr : assignment_target INCR | assignment_target DECR """ - p[0] = IncrDecr(target=p[1], operator=p[2], sourceref=_token_sref(p, 2)) + p[0] = IncrDecr(operator=p[2], sourceref=_token_sref(p, 2)) + p[0].target = p[1] def p_call_subroutine(p): """ subroutine_call : calltarget preserveregs_opt '(' call_arguments_opt ')' """ - p[0] = SubCall(target=p[1], preserve_regs=p[2], arguments=p[4], sourceref=_token_sref(p, 3)) + p[0] = SubCall(sourceref=_token_sref(p, 3)) + p[0].nodes.append(p[1]) + p[0].nodes.append(p[2]) + p[0].nodes.append(CallArguments(nodes=p[4] or [], sourceref=p[0].sourceref)) def p_preserveregs_opt(p): @@ -1234,9 +1289,11 @@ def p_call_argument(p): | NAME IS expression """ if len(p) == 2: - p[0] = CallArgument(value=p[1], sourceref=_token_sref(p, 1)) + p[0] = CallArgument(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[1]) elif len(p) == 4: - p[0] = CallArgument(name=p[1], value=p[3], sourceref=_token_sref(p, 1)) + p[0] = CallArgument(name=p[1], sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[3]) def p_return(p): @@ -1249,11 +1306,17 @@ def p_return(p): if len(p) == 2: p[0] = Return(sourceref=_token_sref(p, 1)) elif len(p) == 3: - p[0] = Return(value_A=p[2], sourceref=_token_sref(p, 1)) + p[0] = Return(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[2]) # A elif len(p) == 5: - p[0] = Return(value_A=p[2], value_X=p[4], sourceref=_token_sref(p, 1)) + p[0] = Return(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[2]) # A + p[0].nodes.append(p[4]) # X elif len(p) == 7: - p[0] = Return(value_A=p[2], value_X=p[4], value_Y=p[6], sourceref=_token_sref(p, 1)) + p[0] = Return(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[2]) # A + p[0].nodes.append(p[4]) # X + p[0].nodes.append(p[6]) # Y def p_register(p): @@ -1267,21 +1330,25 @@ def p_goto(p): """ goto : GOTO calltarget """ - p[0] = Goto(target=p[2], sourceref=_token_sref(p, 1)) + p[0] = Goto(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[2]) def p_conditional_goto_plain(p): """ conditional_goto : IF GOTO calltarget """ - p[0] = Goto(target=p[3], if_stmt=p[1], sourceref=_token_sref(p, 1)) + p[0] = Goto(if_stmt=p[1], sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[3]) def p_conditional_goto_expr(p): """ conditional_goto : IF expression GOTO calltarget """ - p[0] = Goto(target=p[4], if_stmt=p[1], condition=p[2], sourceref=_token_sref(p, 1)) + p[0] = Goto(if_stmt=p[1], sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[4]) + p[0].nodes.append(p[2]) def p_calltarget(p): @@ -1290,17 +1357,15 @@ def p_calltarget(p): | INTEGER | dereference """ - if len(p) == 2: - p[0] = CallTarget(target=p[1], address_of=False, sourceref=_token_sref(p, 1)) - elif len(p) == 3: - p[0] = CallTarget(target=p[2], address_of=True, sourceref=_token_sref(p, 1)) + p[0] = p[1] def p_dereference(p): """ dereference : '[' dereference_operand ']' """ - p[0] = Dereference(location=p[2][0], datatype=p[2][1], sourceref=_token_sref(p, 1)) + p[0] = Dereference(datatype=p[2][1], sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[2][0]) def p_dereference_operand(p): @@ -1325,14 +1390,18 @@ def p_assignment(p): assignment : assignment_target IS expression | assignment_target IS assignment """ - p[0] = Assignment(left=[p[1]], right=p[3], sourceref=_token_sref(p, 2)) + p[0] = Assignment(sourceref=_token_sref(p, 2)) + p[0].nodes.append(AssignmentTargets(nodes=[p[1]], sourceref=p[0].sourceref)) + p[0].nodes.append(p[3]) def p_aug_assignment(p): """ aug_assignment : assignment_target AUGASSIGN expression """ - p[0] = AugAssignment(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2)) + p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2)) + p[0].nodes.append(p[1]) + p[0].nodes.append(p[2]) precedence = ( @@ -1413,9 +1482,9 @@ def p_assignment_target(p): """ if isinstance(p[1], TargetRegisters): # if the target registers is just a single register, use that instead - if len(p[1].registers) == 1: - assert isinstance(p[1].registers[0], Register) - p[1] = p[1].registers[0] + if len(p[1].nodes) == 1: + assert isinstance(p[1].nodes[0], Register) + p[1] = p[1].nodes[0] p[0] = p[1] @@ -1425,9 +1494,10 @@ def p_target_registers(p): | target_registers ',' register """ if len(p) == 2: - p[0] = TargetRegisters(registers=[p[1]], sourceref=_token_sref(p, 1)) + p[0] = TargetRegisters(sourceref=_token_sref(p, 1)) + p[0].nodes.append(p[1]) else: - p[1].add(p[3]) + p[1].nodes.append(p[3]) p[0] = p[1] @@ -1485,6 +1555,13 @@ class TokenFilter: parser = yacc(write_tables=True) +def connect_parents(node: AstNode, parent: AstNode) -> None: + node.parent = parent + for childnode in node.nodes: + if isinstance(childnode, AstNode): + connect_parents(childnode, node) + + def parse_file(filename: str, lexer_error_func=None) -> Module: lexer.error_function = lexer_error_func lexer.lineno = 1 @@ -1492,4 +1569,6 @@ def parse_file(filename: str, lexer_error_func=None) -> Module: tfilter = TokenFilter(lexer) with open(filename, "rU") as inf: sourcecode = inf.read() - return parser.parse(input=sourcecode, tokenfunc=tfilter.token) + result = parser.parse(input=sourcecode, tokenfunc=tfilter.token) + connect_parents(result, None) + return result diff --git a/tests/test_core.py b/tests/test_core.py index b8e469a38..d034e8ad8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,6 @@ import pytest from il65 import datatypes -from il65.plyparse import coerce_constant_value +from il65.plyparse import coerce_constant_value, LiteralValue from il65.compile import ParseError from il65.plylex import SourceRef from il65.emit import to_hex, to_mflpt5 @@ -101,39 +101,42 @@ def test_char_to_bytevalue(): def test_coerce_value(): - assert coerce_constant_value(datatypes.DataType.BYTE, 0) == (False, 0) - assert coerce_constant_value(datatypes.DataType.BYTE, 255) == (False, 255) - assert coerce_constant_value(datatypes.DataType.BYTE, '@') == (True, 64) - assert coerce_constant_value(datatypes.DataType.WORD, 0) == (False, 0) - assert coerce_constant_value(datatypes.DataType.WORD, 65535) == (False, 65535) - assert coerce_constant_value(datatypes.DataType.WORD, '@') == (True, 64) - assert coerce_constant_value(datatypes.DataType.FLOAT, -999.22) == (False, -999.22) - assert coerce_constant_value(datatypes.DataType.FLOAT, 123.45) == (False, 123.45) - assert coerce_constant_value(datatypes.DataType.FLOAT, '@') == (True, 64) - assert coerce_constant_value(datatypes.DataType.BYTE, 5.678) == (True, 5) - assert coerce_constant_value(datatypes.DataType.WORD, 5.678) == (True, 5) - assert coerce_constant_value(datatypes.DataType.WORD, "string") == (False, "string"), "string (address) can be assigned to a word" - assert coerce_constant_value(datatypes.DataType.STRING, "string") == (False, "string") - assert coerce_constant_value(datatypes.DataType.STRING_P, "string") == (False, "string") - assert coerce_constant_value(datatypes.DataType.STRING_S, "string") == (False, "string") - assert coerce_constant_value(datatypes.DataType.STRING_PS, "string") == (False, "string") + def lv(v) -> LiteralValue: + return LiteralValue(value=v, sourceref=SourceRef("test", 1, 1)) + assert coerce_constant_value(datatypes.DataType.BYTE, lv(0)) == (False, lv(0)) + assert coerce_constant_value(datatypes.DataType.BYTE, lv(255)) == (False, lv(255)) + assert coerce_constant_value(datatypes.DataType.BYTE, lv('@')) == (True, lv(64)) + assert coerce_constant_value(datatypes.DataType.WORD, lv(0)) == (False, lv(0)) + assert coerce_constant_value(datatypes.DataType.WORD, lv(65535)) == (False, lv(65535)) + assert coerce_constant_value(datatypes.DataType.WORD, lv('@')) == (True, lv(64)) + assert coerce_constant_value(datatypes.DataType.FLOAT, lv(-999.22)) == (False, lv(-999.22)) + assert coerce_constant_value(datatypes.DataType.FLOAT, lv(123.45)) == (False, lv(123.45)) + assert coerce_constant_value(datatypes.DataType.FLOAT, lv('@')) == (True, lv(64)) + assert coerce_constant_value(datatypes.DataType.BYTE, lv(5.678)) == (True, lv(5)) + assert coerce_constant_value(datatypes.DataType.WORD, lv(5.678)) == (True, lv(5)) + assert coerce_constant_value(datatypes.DataType.WORD, + lv("string")) == (False, lv("string")), "string (address) can be assigned to a word" + assert coerce_constant_value(datatypes.DataType.STRING, lv("string")) == (False, lv("string")) + assert coerce_constant_value(datatypes.DataType.STRING_P, lv("string")) == (False, lv("string")) + assert coerce_constant_value(datatypes.DataType.STRING_S, lv("string")) == (False, lv("string")) + assert coerce_constant_value(datatypes.DataType.STRING_PS, lv("string")) == (False, lv("string")) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.BYTE, -1) + coerce_constant_value(datatypes.DataType.BYTE, lv(-1)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.BYTE, 256) + coerce_constant_value(datatypes.DataType.BYTE, lv(256)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.BYTE, 256.12345) + coerce_constant_value(datatypes.DataType.BYTE, lv(256.12345)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.WORD, -1) + coerce_constant_value(datatypes.DataType.WORD, lv(-1)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.WORD, 65536) + coerce_constant_value(datatypes.DataType.WORD, lv(65536)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.WORD, 65536.12345) + coerce_constant_value(datatypes.DataType.WORD, lv(65536.12345)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.FLOAT, -1.7014118346e+38) + coerce_constant_value(datatypes.DataType.FLOAT, lv(-1.7014118346e+38)) with pytest.raises(OverflowError): - coerce_constant_value(datatypes.DataType.FLOAT, 1.7014118347e+38) + coerce_constant_value(datatypes.DataType.FLOAT, lv(1.7014118347e+38)) with pytest.raises(TypeError): - coerce_constant_value(datatypes.DataType.BYTE, "string") + coerce_constant_value(datatypes.DataType.BYTE, lv("string")) with pytest.raises(TypeError): - coerce_constant_value(datatypes.DataType.FLOAT, "string") + coerce_constant_value(datatypes.DataType.FLOAT, lv("string")) diff --git a/tests/test_parser.py b/tests/test_parser.py index 9f7a9d80d..932d8a595 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,12 +1,13 @@ from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef -from il65.plyparse import parser, TokenFilter, Module, Subroutine, Block, Return, Scope, \ - VarDef, Expression, LiteralValue, Label, SubCall, CallTarget, SymbolName, Dereference -from il65.datatypes import DataType, char_to_bytevalue +from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, Return, Scope, \ + VarDef, Expression, LiteralValue, Label, SubCall, Dereference +from il65.datatypes import DataType def lexer_error(sourceref: SourceRef, fmtstring: str, *args: str) -> None: print("ERROR: {}: {}".format(sourceref, fmtstring.format(*args))) + lexer.error_function = lexer_error @@ -112,6 +113,7 @@ def test_parser(): lexer.source_filename = "sourcefile" filter = TokenFilter(lexer) result = parser.parse(input=test_source_1, tokenfunc=filter.token) + connect_parents(result, None) assert isinstance(result, Module) assert result.name == "sourcefile" assert result.scope.name == "" @@ -123,7 +125,6 @@ def test_parser(): block = result.scope.lookup("block") assert isinstance(block, Block) assert block.name == "block" - assert block.nodes is block.scope.nodes bool_vdef = block.scope.nodes[1] assert isinstance(bool_vdef, VarDef) assert isinstance(bool_vdef.value, Expression) @@ -134,30 +135,26 @@ def test_parser(): sub2 = block.scope.lookup("calculate") assert sub2 is sub assert sub2.lineref == "src l. 19" - all_scopes = list(result.all_scopes()) - assert len(all_scopes) == 3 - assert isinstance(all_scopes[0][0], Module) - assert all_scopes[0][1] is None - assert isinstance(all_scopes[1][0], Block) - assert isinstance(all_scopes[1][1], Module) - assert isinstance(all_scopes[2][0], Subroutine) - assert isinstance(all_scopes[2][1], Block) - stmt = list(all_scopes[2][0].scope.filter_nodes(Return)) - assert len(stmt) == 1 - assert isinstance(stmt[0], Return) - assert stmt[0].lineref == "src l. 20" + all_nodes = list(result.all_nodes()) + assert len(all_nodes) == 12 + all_nodes = list(result.all_nodes([Subroutine])) + assert len(all_nodes) == 1 + assert isinstance(all_nodes[0], Subroutine) + assert isinstance(all_nodes[0].parent, Scope) + assert all_nodes[0] in all_nodes[0].parent.nodes + assert all_nodes[0].lineref == "src l. 19" + assert all_nodes[0].parent.lineref == "src l. 8" def test_block_nodes(): sref = SourceRef("file", 1, 1) sub1 = Subroutine(name="subaddr", param_spec=[], result_spec=[], address=0xc000, sourceref=sref) - sub2 = Subroutine(name="subblock", param_spec=[], result_spec=[], - scope=Scope(nodes=[Label(name="start", sourceref=sref)], sourceref=sref), sourceref=sref) + sub2 = Subroutine(name="subblock", param_spec=[], result_spec=[], sourceref=sref) + sub2.scope = Scope(nodes=[Label(name="start", sourceref=sref)], level="block", sourceref=sref) assert sub1.scope is None assert sub1.nodes == [] assert sub2.scope is not None assert len(sub2.scope.nodes) > 0 - assert sub2.nodes is sub2.scope.nodes test_source_2 = """ @@ -173,20 +170,18 @@ def test_parser_2(): lexer.source_filename = "sourcefile" filter = TokenFilter(lexer) result = parser.parse(input=test_source_2, tokenfunc=filter.token) - block = result.nodes[0] - call = block.nodes[0] + connect_parents(result, None) + block = result.scope.nodes[0] + call = block.scope.nodes[0] assert isinstance(call, SubCall) - assert len(call.arguments) == 2 - assert isinstance(call.target, CallTarget) - assert call.target.target == 999 - assert call.target.address_of is False - call = block.nodes[1] + assert len(call.arguments.nodes) == 2 + assert isinstance(call.target, int) + assert call.target == 999 + call = block.scope.nodes[1] assert isinstance(call, SubCall) - assert len(call.arguments) == 0 - assert isinstance(call.target, CallTarget) - assert isinstance(call.target.target, Dereference) - assert call.target.target.location.name == "zz" - assert call.target.address_of is False + assert len(call.arguments.nodes) == 0 + assert isinstance(call.target, Dereference) + assert call.target.operand.name == "zz" test_source_3 = """ @@ -198,33 +193,35 @@ test_source_3 = """ } """ + def test_typespec(): lexer.lineno = 1 lexer.source_filename = "sourcefile" filter = TokenFilter(lexer) result = parser.parse(input=test_source_3, tokenfunc=filter.token) - nodes = result.nodes[0].nodes - assignment1, assignment2, assignment3, assignment4 = nodes + connect_parents(result, None) + block = result.scope.nodes[0] + assignment1, assignment2, assignment3, assignment4 = block.scope.nodes assert assignment1.right.value == 5 assert assignment2.right.value == 5 assert assignment3.right.value == 5 assert assignment4.right.value == 5 - assert len(assignment1.left) == 1 - assert len(assignment2.left) == 1 - assert len(assignment3.left) == 1 - assert len(assignment4.left) == 1 - t1 = assignment1.left[0] - t2 = assignment2.left[0] - t3 = assignment3.left[0] - t4 = assignment4.left[0] + assert len(assignment1.left.nodes) == 1 + assert len(assignment2.left.nodes) == 1 + assert len(assignment3.left.nodes) == 1 + assert len(assignment4.left.nodes) == 1 + t1 = assignment1.left.nodes[0] + t2 = assignment2.left.nodes[0] + t3 = assignment3.left.nodes[0] + t4 = assignment4.left.nodes[0] assert isinstance(t1, Dereference) assert isinstance(t2, Dereference) assert isinstance(t3, Dereference) assert isinstance(t4, Dereference) - assert t1.location == 0xc000 - assert t2.location == 0xc000 - assert t3.location == "AX" - assert t4.location == "AX" + assert t1.operand == 0xc000 + assert t2.operand == 0xc000 + assert t3.operand == "AX" + assert t4.operand == "AX" assert t1.datatype == DataType.WORD assert t2.datatype == DataType.BYTE assert t3.datatype == DataType.WORD @@ -252,8 +249,9 @@ def test_char_string(): lexer.source_filename = "sourcefile" filter = TokenFilter(lexer) result = parser.parse(input=test_source_4, tokenfunc=filter.token) - nodes = result.nodes[0].nodes - var1, var2, var3, assgn1, assgn2, assgn3, = nodes + connect_parents(result, None) + block = result.scope.nodes[0] + var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes assert var1.value.value == 64 assert var2.value.value == 126 assert var3.value.value == "abc" @@ -278,8 +276,9 @@ def test_boolean_int(): lexer.source_filename = "sourcefile" filter = TokenFilter(lexer) result = parser.parse(input=test_source_5, tokenfunc=filter.token) - nodes = result.nodes[0].nodes - var1, var2, assgn1, assgn2, = nodes + connect_parents(result, None) + block = result.scope.nodes[0] + var1, var2, assgn1, assgn2, = block.scope.nodes assert type(var1.value.value) is int and var1.value.value == 1 assert type(var2.value.value) is int and var2.value.value == 0 assert type(assgn1.right.value) is int and assgn1.right.value == 1 diff --git a/todo.ill b/todo.ill index b62b7363b..c6778ad09 100644 --- a/todo.ill +++ b/todo.ill @@ -3,7 +3,6 @@ ~ main { - var .byte v1t = true var .byte v1f = false var .word v2t = true