From 68c1d2af4c57bc58ef13b46c344829652a2cf6b6 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sun, 7 Jan 2018 19:14:21 +0100 Subject: [PATCH] comp --- il65/exprparse.py | 2 +- il65/lexer.py | 14 ++-- il65/parse.py | 6 +- il65/plycompiler.py | 153 +++++++++++++++++++++++++++++++++++++++++ il65/plyparser.py | 162 +++++++++++++++++++++++++++++++------------- il65/symbols.py | 23 ++++--- lib/mathlib.ill | 1 - reference.md | 2 +- todo.ill | 37 ++++++++++ 9 files changed, 330 insertions(+), 70 deletions(-) create mode 100644 il65/plycompiler.py diff --git a/il65/exprparse.py b/il65/exprparse.py index ef78f7f9a..b2e2dc43b 100644 --- a/il65/exprparse.py +++ b/il65/exprparse.py @@ -13,7 +13,7 @@ from .symbols import FLOAT_MAX_POSITIVE, FLOAT_MAX_NEGATIVE, SourceRef, SymbolTa class ParseError(Exception): - def __init__(self, message: str, sourcetext: str, sourceref: SourceRef) -> None: + def __init__(self, message: str, sourcetext: Optional[str], sourceref: SourceRef) -> None: self.sourceref = sourceref self.msg = message self.sourcetext = sourcetext diff --git a/il65/lexer.py b/il65/lexer.py index 23b4a5bf6..4da60b33f 100644 --- a/il65/lexer.py +++ b/il65/lexer.py @@ -9,6 +9,8 @@ License: GNU GPL 3.0, see LICENSE import sys import ply.lex from .symbols import SourceRef +from .parse import ParseError + # token names @@ -218,7 +220,7 @@ def t_NAME(t): def t_DIRECTIVE(t): - r"%[a-z]+" + r"%[a-z]+\b" t.value = t.value[1:] return t @@ -284,13 +286,15 @@ def t_PRESERVEREGS(t): def t_ENDL(t): r"\n+" t.lexer.lineno += len(t.value) + t.value = "\n" return t # end of lines are significant to the parser def t_error(t): line, col = t.lineno, find_tok_column(t) - sref = SourceRef("@todo-filename-f1", line, col) - t.lexer.error_function("{}: Illegal character '{:s}'", sref, t.value[0]) + filename = getattr(t.lexer, "source_filename", "") + sref = SourceRef(filename, line, col) + t.lexer.error_function("{}: Illegal character '{:s}'", sref, t.value[0], sourceref=sref) t.lexer.skip(1) @@ -300,8 +304,8 @@ def find_tok_column(token): return token.lexpos - last_cr -def error_function(fmtstring, *args): - print("ERROR:", fmtstring.format(*args), file=sys.stderr) +def error_function(fmtstring, *args, sourceref: SourceRef=None) -> None: + raise ParseError(fmtstring.format(*args), None, sourceref=sourceref) lexer = ply.lex.lex() diff --git a/il65/parse.py b/il65/parse.py index cabbad0a6..5db56f623 100644 --- a/il65/parse.py +++ b/il65/parse.py @@ -475,9 +475,9 @@ class Parser: if self.result.format == ProgramFormat.PRG and self.result.with_sys and self.result.start_address != 0x0801: raise self.PError("cannot use non-default 'address' when output format includes basic SYS program") continue - elif directive == "preserve_registers": + elif directive == "saveregisters": if preserve_specified: - raise self.PError("can only specify preserve_registers option once") + raise self.PError("can only specify saveregisters option once") preserve_specified = True _, _, optionstr = line.partition(" ") self.result.preserve_registers = optionstr in ("", "true", "yes") @@ -649,7 +649,7 @@ class Parser: elif directive == "breakpoint": self.cur_block.statements.append(BreakpointStmt(self.sourceref)) self.print_warning("breakpoint defined") - elif directive == "preserve_registers": + elif directive == "saveregisters": self.result.preserve_registers = optionstr in ("", "true", "yes") else: raise self.PError("invalid directive") diff --git a/il65/plycompiler.py b/il65/plycompiler.py new file mode 100644 index 000000000..e95daca10 --- /dev/null +++ b/il65/plycompiler.py @@ -0,0 +1,153 @@ +import os +import sys +import linecache +from typing import Optional, Generator, Tuple, Set +from .plyparser import parse_file, Module, Directive, Block, Subroutine, AstNode +from .parse import ParseError +from .symbols import SourceRef + + +class PlyParser: + def __init__(self): + self.parse_errors = 0 + self.parsing_import = False + + def parse_file(self, filename: str) -> Module: + print("parsing:", filename) + module = parse_file(filename) + try: + self.check_directives(module) + self.remove_empty_blocks(module) + self.process_imports(module) + except ParseError as x: + self.handle_parse_error(x) + return module + + def remove_empty_blocks(self, module: Module) -> None: + # remove blocks without name and without address, or that are empty + for scope, parent in self.recurse_scopes(module): + if isinstance(scope, (Subroutine, Block)): + if not scope.scope: + continue + if all(isinstance(n, Directive) for n in scope.scope.nodes): + empty = True + for n in scope.scope.nodes: + empty = empty and n.name not in {"asmbinary", "asminclude"} + if empty: + self.print_warning("ignoring empty block or subroutine", scope.sourceref) + assert isinstance(parent, (Block, Module)) + parent.scope.nodes.remove(scope) + if isinstance(scope, Block): + if not scope.name and scope.address is None: + self.print_warning("ignoring block without name and address", scope.sourceref) + assert isinstance(parent, Module) + parent.scope.nodes.remove(scope) + + def check_directives(self, module: Module) -> None: + for node, parent in self.recurse_scopes(module): + if isinstance(node, Module): + # check module-level directives + imports = set() # type: Set[str] + for directive in node.scope.filter_nodes(Directive): + if directive.name not in {"output", "zp", "address", "import", "saveregisters"}: + raise ParseError("invalid directive in module", None, directive.sourceref) + if directive.name == "import": + if imports & set(directive.args): + raise ParseError("duplicate import", None, directive.sourceref) + imports |= set(directive.args) + if isinstance(node, (Block, Subroutine)): + # check block and subroutine-level directives + first_node = True + if not node.scope: + continue + for sub_node in node.scope.nodes: + if isinstance(sub_node, Directive): + if sub_node.name not in {"asmbinary", "asminclude", "breakpoint", "saveregisters"}: + raise ParseError("invalid directive in " + node.__class__.__name__.lower(), None, sub_node.sourceref) + if sub_node.name == "saveregisters" and not first_node: + raise ParseError("saveregisters directive should be the first", None, sub_node.sourceref) + first_node = False + + def recurse_scopes(self, module: Module) -> Generator[Tuple[AstNode, AstNode], None, None]: + # generator that recursively yields through the scopes (preorder traversal), yields (node, parent_node) tuples. + yield module, None + for block in list(module.scope.filter_nodes(Block)): + yield block, module + for subroutine in list(block.scope.filter_nodes(Subroutine)): + yield subroutine, block + + def process_imports(self, module: Module) -> None: + # (recursively) imports the modules + imported = [] + for directive in module.scope.filter_nodes(Directive): + if directive.name == "import": + if len(directive.args) < 1: + raise ParseError("missing argument(s) for import directive", None, directive.sourceref) + for arg in directive.args: + filename = self.find_import_file(arg, directive.sourceref.file) + if not filename: + raise ParseError("imported file not found", None, directive.sourceref) + imported_module = self.import_file(filename) + imported_module.scope.parent_scope = module.scope + imported.append(imported_module) + # append the imported module's contents (blocks) at the end of the current module + for imported_module in imported: + for block in imported_module.scope.filter_nodes(Block): + module.scope.nodes.append(block) + + def import_file(self, filename: str) -> Module: + sub_parser = PlyParser() + return sub_parser.parse_file(filename) + + def find_import_file(self, modulename: str, sourcefile: str) -> Optional[str]: + filename_at_source_location = os.path.join(os.path.split(sourcefile)[0], modulename) + filename_at_libs_location = os.path.join(os.getcwd(), "lib", modulename) + candidates = [modulename, + filename_at_source_location, + filename_at_libs_location, + modulename+".ill", + filename_at_source_location+".ill", + filename_at_libs_location+".ill"] + for filename in candidates: + if os.path.isfile(filename): + return filename + return None + + def print_warning(self, text: str, sourceref: SourceRef=None) -> None: + if sourceref: + self.print_bold("warning: {}: {:s}".format(sourceref, text)) + else: + self.print_bold("warning: " + text) + + def print_bold(self, text: str) -> None: + if sys.stdout.isatty(): + print("\x1b[1m" + text + "\x1b[0m", flush=True) + else: + print(text) + + def handle_parse_error(self, exc: ParseError) -> None: + self.parse_errors += 1 + if sys.stderr.isatty(): + print("\x1b[1m", file=sys.stderr) + if self.parsing_import: + print("Error (in imported file):", str(exc), file=sys.stderr) + else: + print("Error:", str(exc), file=sys.stderr) + if exc.sourcetext is None: + exc.sourcetext = linecache.getline(exc.sourceref.file, exc.sourceref.line).rstrip() + if exc.sourcetext: + # remove leading whitespace + stripped = exc.sourcetext.lstrip() + num_spaces = len(exc.sourcetext) - len(stripped) + stripped = stripped.rstrip() + print(" " + stripped, file=sys.stderr) + if exc.sourceref.column: + print(" " + ' ' * (exc.sourceref.column - num_spaces) + '^', file=sys.stderr) + if sys.stderr.isatty(): + print("\x1b[0m", file=sys.stderr, end="", flush=True) + + +if __name__ == "__main__": + plyparser = PlyParser() + m = plyparser.parse_file(sys.argv[1]) + print(str(m)[:400], "...") diff --git a/il65/plyparser.py b/il65/plyparser.py index a8edd2a7a..7687a5716 100644 --- a/il65/plyparser.py +++ b/il65/plyparser.py @@ -8,7 +8,7 @@ License: GNU GPL 3.0, see LICENSE import attr from ply.yacc import yacc -from typing import Union +from typing import Union, Type, Generator from .symbols import SourceRef from .lexer import tokens, lexer, find_tok_column # get the lexer tokens. required. @@ -48,70 +48,121 @@ class AstNode: tostr(self, 0) -@attr.s(cmp=False) -class Module(AstNode): - nodes = attr.ib(type=list) - - -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Directive(AstNode): name = attr.ib(type=str) args = attr.ib(type=list, default=attr.Factory(list)) -@attr.s(cmp=False) +@attr.s(cmp=False, slots=True, repr=False) class Scope(AstNode): nodes = attr.ib(type=list) + symbols = attr.ib(init=False) + name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc. + parent_scope = attr.ib(init=False, default=None) # will be wired up later + save_registers = attr.ib(type=bool, default=False, init=False) # XXX will be set later + + def __attrs_post_init__(self): + # populate the symbol table for this scope for fast lookups via scope["name"] or scope["dotted.name"] + self.symbols = {} + for node in self.nodes: + if isinstance(node, (Label, VarDef)): + self.symbols[node.name] = node + if isinstance(node, Subroutine): + self.symbols[node.name] = node + if node.scope is not None: + node.scope.parent_scope = self + if isinstance(node, Block): + if node.name: + self.symbols[node.name] = node + node.scope.parent_scope = self + + def __getitem__(self, name: str) -> AstNode: + if '.' in name: + # look up the dotted name starting from the topmost scope + scope = self + while scope.parent_scope: + scope = scope.parent_scope + for namepart in name.split('.'): + if isinstance(scope, (Block, Subroutine)): + scope = scope.scope + if not isinstance(scope, Scope): + raise LookupError("undefined symbol: " + name) + scope = scope.symbols.get(namepart, None) + if scope is None: + raise LookupError("undefined symbol: " + name) + return scope + else: + # find the name in nested scope hierarchy + if name in self.symbols: + return self.symbols[name] + if self.parent_scope: + return self.parent_scope[name] + raise LookupError("undefined symbol: " + name) + + def filter_nodes(self, nodetype) -> Generator[AstNode, None, None]: + for node in self.nodes: + if isinstance(node, nodetype): + yield node -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) +class Module(AstNode): + name = attr.ib(type=str) # filename + scope = attr.ib(type=Scope) + + +@attr.s(cmp=False, repr=False) class Block(AstNode): scope = attr.ib(type=Scope) name = attr.ib(type=str, default=None) address = attr.ib(type=int, default=None) + def __attrs_post_init__(self): + self.scope.name = self.name -@attr.s(cmp=False) + +@attr.s(cmp=False, repr=False) class Label(AstNode): name = attr.ib(type=str) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Register(AstNode): name = attr.ib(type=str) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class PreserveRegs(AstNode): registers = attr.ib(type=str) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Assignment(AstNode): left = attr.ib() # type: Union[str, TargetRegisters, Dereference] right = attr.ib() -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class AugAssignment(Assignment): operator = attr.ib(type=str) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class SubCall(AstNode): target = attr.ib() preserve_regs = attr.ib() arguments = attr.ib() -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Return(AstNode): value_A = attr.ib(default=None) value_X = attr.ib(default=None) value_Y = attr.ib(default=None) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class TargetRegisters(AstNode): registers = attr.ib(type=list) @@ -119,12 +170,12 @@ class TargetRegisters(AstNode): self.registers.append(register) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class InlineAssembly(AstNode): assembly = attr.ib(type=str) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class VarDef(AstNode): name = attr.ib(type=str) vartype = attr.ib() @@ -132,13 +183,13 @@ class VarDef(AstNode): value = attr.ib(default=None) -@attr.s(cmp=False, slots=True) +@attr.s(cmp=False, slots=True, repr=False) class Datatype(AstNode): name = attr.ib(type=str) dimension = attr.ib(type=list, default=None) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Subroutine(AstNode): name = attr.ib(type=str) param_spec = attr.ib() @@ -149,40 +200,42 @@ class Subroutine(AstNode): def __attrs_post_init__(self): if self.scope is not None and self.address is not None: raise ValueError("subroutine must have either a scope or an address, not both") + if self.scope is not None: + self.scope.name = self.name -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Goto(AstNode): target = attr.ib() if_stmt = attr.ib(default=None) condition = attr.ib(default=None) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class Dereference(AstNode): location = attr.ib() datatype = attr.ib() -@attr.s(cmp=False, slots=True) +@attr.s(cmp=False, slots=True, repr=False) class CallTarget(AstNode): target = attr.ib() address_of = attr.ib(type=bool) -@attr.s(cmp=False, slots=True) +@attr.s(cmp=False, slots=True, repr=False) class CallArgument(AstNode): value = attr.ib() name = attr.ib(type=str, default=None) -@attr.s(cmp=False) +@attr.s(cmp=False, repr=False) class UnaryOp(AstNode): operator = attr.ib(type=str) operand = attr.ib() -@attr.s(cmp=False, slots=True) +@attr.s(cmp=False, slots=True, repr=False) class Expression(AstNode): left = attr.ib() operator = attr.ib(type=str) @@ -195,7 +248,13 @@ def p_start(p): | module_elements """ if p[1]: - p[0] = Module(nodes=p[1], sourceref=_token_sref(p, 1)) + scope = Scope(nodes=p[1], sourceref=_token_sref(p, 1)) + scope.name = "<" + p.lexer.source_filename + " global scope>" + p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=_token_sref(p, 1)) + else: + scope = Scope(nodes=[], sourceref=_token_sref(p, 1)) + scope.name = "<" + p.lexer.source_filename + " global scope>" + p[0] = Module(name=p.lexer.source_filename, scope=scope, sourceref=SourceRef(lexer.source_filename, 1, 1)) def p_module(p): @@ -214,8 +273,9 @@ def p_module_elt(p): module_elt : ENDL | directive | block - """ - p[0] = p[1] + """ + if p[1] != '\n': + p[0] = p[1] def p_directive(p): @@ -245,6 +305,7 @@ def p_directive_arg(p): directive_arg : NAME | INTEGER | STRING + | BOOLEAN """ p[0] = p[1] @@ -289,7 +350,7 @@ def p_scope_elements_opt(p): """ scope_elements_opt : empty | scope_elements - """ + """ p[0] = p[1] @@ -297,11 +358,14 @@ def p_scope_elements(p): """ scope_elements : scope_element | scope_elements scope_element - """ + """ if len(p) == 2: - p[0] = [p[1]] + p[0] = [] if p[1] in (None, '\n') else [p[1]] else: - p[0] = p[1] + [p[2]] + if p[2] in (None, '\n'): + p[0] = p[1] + else: + p[0] = p[1] + [p[2]] def p_scope_element(p): @@ -314,7 +378,10 @@ def p_scope_element(p): | inlineasm | statement """ - p[0] = p[1] + if p[1] != '\n': + p[0] = p[1] + else: + p[0] = None def p_label(p): @@ -729,17 +796,18 @@ def p_empty(p): def p_error(p): + stack_state_str = ' '.join([symbol.type for symbol in parser.symstack][1:]) + print('\n[ERROR DEBUG: parser state={:d} stack: {} . {} ]'.format(parser.state, stack_state_str, p)) if p: sref = SourceRef(p.lexer.source_filename, p.lineno, find_tok_column(p)) - p.lexer.error_function("{}: before '{:.20s}' ({})", sref, str(p.value), repr(p)) + p.lexer.error_function("syntax error before '{:.20s}'", str(p.value), sourceref=sref) else: - lexer.error_function("{}: at end of input", "@todo-filename3") + lexer.error_function("syntax error at end of input", lexer.source_filename, sourceref=None) def _token_sref(p, token_idx): """ Returns the coordinates for the YaccProduction object 'p' indexed - with 'token_idx'. The coordinate includes the 'lineno' and - 'column'. Both follow the lex semantic, starting from 1. + with 'token_idx'. The coordinate includes the 'lineno' and 'column', starting from 1. """ last_cr = p.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx)) if last_cr < 0: @@ -772,12 +840,10 @@ class TokenFilter: parser = yacc(write_tables=True) -if __name__ == "__main__": - import sys - file = sys.stdin # open(sys.argv[1], "rU") - lexer.source_filename = "derp" - tokenfilter = TokenFilter(lexer) - result = parser.parse(input=file.read(), - tokenfunc=tokenfilter.token) or Module(None, SourceRef(lexer.source_filename, 1, 1)) - print("RESULT:") - result.print_tree() +def parse_file(filename: str) -> Module: + lexer.lineno = 1 + lexer.source_filename = filename + tfilter = TokenFilter(lexer) + with open(filename, "rU") as inf: + sourcecode = inf.read() + return parser.parse(input=sourcecode, tokenfunc=tfilter.token) diff --git a/il65/symbols.py b/il65/symbols.py index c02df38a2..3898a1839 100644 --- a/il65/symbols.py +++ b/il65/symbols.py @@ -710,17 +710,18 @@ class AstNode: try: variables = vars(node).items() except TypeError: - variables = {} - for name, value in variables: - if isinstance(value, AstNode): - tostr(value, level + 1) - if isinstance(value, (list, tuple, set)): - if len(value) > 0: - elt = list(value)[0] - if isinstance(elt, AstNode) or name == "nodes": - print(indent, " >", name, "=") - for elt in value: - tostr(elt, level + 2) + pass + else: + for name, value in variables: + if isinstance(value, AstNode): + tostr(value, level + 1) + if isinstance(value, (list, tuple, set)): + if len(value) > 0: + elt = list(value)[0] + if isinstance(elt, AstNode) or name == "nodes": + print(indent, " >", name, "=") + for elt in value: + tostr(elt, level + 2) tostr(self, 0) diff --git a/lib/mathlib.ill b/lib/mathlib.ill index e594c9b34..74f80617b 100644 --- a/lib/mathlib.ill +++ b/lib/mathlib.ill @@ -59,7 +59,6 @@ _m_with_add stx SCRATCH_ZP1 } } - sub multiply_bytes_addA_16 (byte1: X, byte2: Y, add: A) -> (A?, XY) { ; ---- multiply 2 bytes and add A, result as word in X/Y (unsigned) %asm { diff --git a/reference.md b/reference.md index a07265c78..ec0982860 100644 --- a/reference.md +++ b/reference.md @@ -400,7 +400,7 @@ Normally, the registers are NOT preserved when calling a subroutine or when a ce operations are performed. Most calls will be simply a few instructions to load the values in the registers and then a JSR or JMP. -By using the ``%preserve_registers`` directive (globally or in a block) you can tell the +By using the ``%saveregisters`` directive (globally or in a block) you can tell the compiler to preserve all registers. This does generate a lot of extra code that puts original values on the stack and gets them off the stack again once the subroutine is done. In this case however you don't have to worry about A, X and Y losing their original values diff --git a/todo.ill b/todo.ill index 1f43b7410..1e08cdc1f 100644 --- a/todo.ill +++ b/todo.ill @@ -4,12 +4,17 @@ ~ main $4444 { + %saveregisters true, false + const num = 2 var var1 =2 var .word wvar1 = 2 + start: + + A=math.randbyte() A += c64.RASTER A-=c64.TIME_LO @@ -161,6 +166,38 @@ loop : c64scr.print_word_decimal(1222) c64.CHROUT('\n') + %breakpoint + return +sub sub1 () -> () { + + %saveregisters off + %breakpoint + +label: + return + +} + +sub emptysub () -> () { + + %saveregisters on + +} + +} + +~ zzzz { + + %saveregisters + %breakpoint + + return + +} +~ { + ;sdfsdf + return + ;sdfsdf }