""" Programming Language for 6502/6510 microprocessors, codename 'Sick' This is the parser of the IL65 code, that generates a parse tree. Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ import math import builtins import inspect import enum from collections import defaultdict from typing import Union, Generator, Tuple, List, Optional, Dict, Any, Iterable import attr from ply.yacc import yacc from .plylex import SourceRef, tokens, lexer, find_tok_column, print_warning from .datatypes import DataType, VarType, REGISTER_SYMBOLS, REGISTER_BYTES, REGISTER_WORDS, \ char_to_bytevalue, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE class ProgramFormat(enum.Enum): RAW = "raw" PRG = "prg" BASIC = "basicprg" class ZpOptions(enum.Enum): NOCLOBBER = "noclobber" CLOBBER = "clobber" CLOBBER_RESTORE = "clobber_restore" math_functions = {name: func for name, func in vars(math).items() if inspect.isbuiltin(func)} builtin_functions = {name: func for name, func in vars(builtins).items() if inspect.isbuiltin(func)} class ParseError(Exception): def __init__(self, message: str, sourceref: SourceRef) -> None: super().__init__(message) self.sourceref = sourceref # @todo chain attribute, a list of other exceptions, so we can have more than 1 error at a time. def __str__(self): return "{} {:s}".format(self.sourceref, self.args[0]) class ExpressionEvaluationError(ParseError): pass class UndefinedSymbolError(LookupError): pass start = "start" @attr.s(cmp=False, slots=True, frozen=False) class AstNode: sourceref = attr.ib(type=SourceRef) # when evaluating an expression, does it have to be a constant value?: processed_expr_must_be_constant = attr.ib(type=bool, init=False, default=False) @property def lineref(self) -> str: return "src l. 
" + str(self.sourceref.line) def print_tree(self) -> None: def tostr(node: AstNode, level: int) -> None: if not isinstance(node, AstNode): return indent = " " * level name = getattr(node, "name", "") print(indent, node.__class__.__name__, repr(name)) try: variables = vars(node).items() except TypeError: return for name, value in variables: if isinstance(value, AstNode): tostr(value, level + 1) if isinstance(value, (list, tuple, set)): if len(value) > 0: elt = list(value)[0] if isinstance(elt, AstNode) or name == "nodes": print(indent, " >", name, "=") for elt in value: tostr(elt, level + 2) tostr(self, 0) def process_expressions(self, scope: 'Scope') -> None: # process/simplify all expressions (constant folding etc) # this is implemented in node types that have expression(s) and that should act on this. pass def verify_symbol_names(self, scope: 'Scope') -> None: # check all SymbolNames to see if they exist. # this is implemented in node types that have expression(s) and that should act on this. pass @attr.s(cmp=False, repr=False) class Directive(AstNode): name = attr.ib(type=str) args = attr.ib(type=list, default=attr.Factory(list)) @attr.s(cmp=False, slots=True, repr=False) class Scope(AstNode): nodes = attr.ib(type=list) symbols = attr.ib(init=False) name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc. 
parent_scope = attr.ib(init=False, default=None) # will be wired up later save_registers = attr.ib(type=bool, default=None, init=False) # None = look in parent scope's setting @todo property that does that def __attrs_post_init__(self): # populate the symbol table for this scope for fast lookups via scope.lookup("name") or scope.lookup("dotted.name") self.symbols = {} for node in self.nodes: assert isinstance(node, AstNode) self._populate_symboltable(node) def _populate_symboltable(self, node: AstNode) -> None: if isinstance(node, (Label, VarDef)): if node.name in self.symbols: raise ParseError("symbol already defined at {}".format(self.symbols[node.name].sourceref), node.sourceref) self.symbols[node.name] = node if isinstance(node, Subroutine): if node.name in self.symbols: raise ParseError("symbol already defined at {}".format(self.symbols[node.name].sourceref), node.sourceref) self.symbols[node.name] = node if node.scope: node.scope.parent_scope = self if isinstance(node, Block): if node.name: if node.name != "ZP" and node.name in self.symbols: raise ParseError("symbol already defined at {}".format(self.symbols[node.name].sourceref), node.sourceref) self.symbols[node.name] = node node.scope.parent_scope = self def lookup(self, name: str) -> AstNode: assert isinstance(name, str) if '.' 
in name: # look up the dotted name starting from the topmost scope scope = self while scope.parent_scope: scope = scope.parent_scope for namepart in name.split('.'): if isinstance(scope, (Block, Subroutine)): scope = scope.scope if not isinstance(scope, Scope): raise UndefinedSymbolError("undefined symbol: " + name) scope = scope.symbols.get(namepart, None) if not scope: raise UndefinedSymbolError("undefined symbol: " + name) return scope else: # find the name in nested scope hierarchy if name in self.symbols: return self.symbols[name] if self.parent_scope: return self.parent_scope.lookup(name) raise UndefinedSymbolError("undefined symbol: " + name) def filter_nodes(self, nodetype) -> Generator[AstNode, None, None]: for node in self.nodes: if isinstance(node, nodetype): yield node def remove_node(self, node: AstNode) -> None: if hasattr(node, "name"): try: del self.symbols[node.name] # type: ignore except KeyError: pass self.nodes.remove(node) def replace_node(self, oldnode: AstNode, newnode: AstNode) -> None: assert isinstance(newnode, AstNode) idx = self.nodes.index(oldnode) self.nodes[idx] = newnode if hasattr(oldnode, "name"): del self.symbols[oldnode.name] # type: ignore def add_node(self, newnode: AstNode, index: int=None) -> None: assert isinstance(newnode, AstNode) if index is None: self.nodes.append(newnode) else: self.nodes.insert(index, newnode) self._populate_symboltable(newnode) def validate_address(obj: AstNode, attrib: attr.Attribute, value: Optional[int]) -> None: if value is None: return if isinstance(obj, Block) and obj.name == "ZP": raise ParseError("zeropage block cannot have custom start {:s}".format(attrib.name), obj.sourceref) if value < 0x0200 or value > 0xffff: raise ParseError("invalid {:s} (must be from $0200 to $ffff)".format(attrib.name), obj.sourceref) def dimensions_validator(obj: 'DatatypeNode', attrib: attr.Attribute, value: List[int]) -> None: if not value: return dt = obj.to_enum() if value and dt not in (DataType.MATRIX, 
DataType.WORDARRAY, DataType.BYTEARRAY): raise ParseError("cannot use a dimension for this datatype", obj.sourceref) if dt == DataType.WORDARRAY or dt == DataType.BYTEARRAY: if len(value) == 1: if value[0] <= 0 or value[0] > 256: raise ParseError("array length must be 1..256", obj.sourceref) else: raise ParseError("array must have only one dimension", obj.sourceref) if dt == DataType.MATRIX: if len(value) < 2 or len(value) > 3: raise ParseError("matrix must have two dimensions, with optional interleave", obj.sourceref) if len(value) == 3: if value[2] < 1 or value[2] > 256: raise ParseError("matrix interleave must be 1..256", obj.sourceref) if value[0] < 0 or value[0] > 128 or value[1] < 0 or value[1] > 128: raise ParseError("matrix rows and columns must be 1..128", obj.sourceref) @attr.s(cmp=False, repr=False) class Block(AstNode): scope = attr.ib(type=Scope) name = attr.ib(type=str, default=None) address = attr.ib(type=int, default=None, validator=validate_address) _unnamed_block_labels = {} # type: Dict[Block, str] def __attrs_post_init__(self): self.scope.name = self.name @property def nodes(self) -> Iterable[AstNode]: if self.scope: return self.scope.nodes return [] @property def label(self) -> str: if self.name: return self.name if self in self._unnamed_block_labels: return self._unnamed_block_labels[self] label = "il65_block_{:d}".format(len(self._unnamed_block_labels)) self._unnamed_block_labels[self] = label return label @attr.s(cmp=False, repr=False) class Module(AstNode): name = attr.ib(type=str) # filename scope = attr.ib(type=Scope) subroutine_usage = attr.ib(type=defaultdict, init=False, default=attr.Factory(lambda: defaultdict(set))) # will be populated later format = attr.ib(type=ProgramFormat, init=False, default=ProgramFormat.PRG) # can be set via directive address = attr.ib(type=int, init=False, default=0xc000, validator=validate_address) # can be set via directive zp_options = attr.ib(type=ZpOptions, init=False, default=ZpOptions.NOCLOBBER) # can 
be set via directive @property def nodes(self) -> Iterable[AstNode]: if self.scope: return self.scope.nodes return [] def all_scopes(self) -> Generator[Tuple[AstNode, AstNode], None, None]: # generator that recursively yields through the scopes (preorder traversal), yields (node, parent_node) tuples. # it iterates of copies of the node collections, so it's okay to modify the scopes you iterate over. yield self, None for block in list(self.scope.filter_nodes(Block)): yield block, self for subroutine in list(block.scope.filter_nodes(Subroutine)): yield subroutine, block def zeropage(self) -> Optional[Block]: # return the zeropage block (if defined) first_block = next(self.scope.filter_nodes(Block)) if first_block.name == "ZP": return first_block return None def main(self) -> Optional[Block]: # return the 'main' block (if defined) for block in self.scope.filter_nodes(Block): if block.name == "main": return block return None @attr.s(cmp=False, repr=False) class Label(AstNode): name = attr.ib(type=str) @attr.s(cmp=False, repr=False) class Register(AstNode): name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS)) datatype = attr.ib(type=DataType, init=False) def __attrs_post_init__(self): if self.name in REGISTER_BYTES: self.datatype = DataType.BYTE elif self.name in REGISTER_WORDS: self.datatype = DataType.WORD else: self.datatype = None # register 'SC' etc. 
def __hash__(self) -> int: return hash(self.name) def __eq__(self, other) -> bool: if not isinstance(other, Register): return NotImplemented return self.name == other.name def __lt__(self, other) -> bool: if not isinstance(other, Register): return NotImplemented return self.name < other.name @attr.s(cmp=False, repr=False) class PreserveRegs(AstNode): registers = attr.ib(type=str) @attr.s(cmp=False, repr=False) class Assignment(AstNode): # can be single- or multi-assignment left = attr.ib(type=list) # type: List[Union[str, TargetRegisters, Dereference]] right = attr.ib() def process_expressions(self, scope: Scope) -> None: self.right = process_expression(self.right, scope, self.right.sourceref) def verify_symbol_names(self, scope: Scope) -> None: for lv in self.left: if isinstance(lv, SymbolName): check_symbol_definition(lv.name, scope, lv.sourceref) elif isinstance(lv, Dereference): if isinstance(lv.location, SymbolName): check_symbol_definition(lv.location.name, scope, lv.location.sourceref) # the symbols in the assignment rvalue are checked when its expression is processed. @attr.s(cmp=False, repr=False) class AugAssignment(AstNode): left = attr.ib() operator = attr.ib(type=str) right = attr.ib() def process_expressions(self, scope: Scope) -> None: self.right = process_expression(self.right, scope, self.right.sourceref) def verify_symbol_names(self, scope: Scope) -> None: if isinstance(self.left, SymbolName): check_symbol_definition(self.left.name, scope, self.left.sourceref) # the symbols in the assignment rvalue are checked when its expression is processed. 
@attr.s(cmp=False, repr=False)
class SubCall(AstNode):
    """A call of a subroutine/function with its call arguments."""
    target = attr.ib()
    preserve_regs = attr.ib()
    arguments = attr.ib()

    def __attrs_post_init__(self):
        self.arguments = self.arguments or []

    def process_expressions(self, scope: Scope) -> None:
        for callarg in self.arguments:
            assert isinstance(callarg, CallArgument)
            callarg.process_expressions(scope)

    def verify_symbol_names(self, scope: Scope) -> None:
        if isinstance(self.target.target, SymbolName):
            check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref)
        # the symbols in the subroutine's arguments are checked when their expression is processed.


@attr.s(cmp=False, repr=False)
class Return(AstNode):
    """Return statement with optional values for the A, X and Y registers."""
    value_A = attr.ib(default=None)
    value_X = attr.ib(default=None)
    value_Y = attr.ib(default=None)

    def process_expressions(self, scope: Scope) -> None:
        # process each given register value; constant values are coerced to a byte
        for attrname, description in (("value_A", "first value (A): "),
                                      ("value_X", "second value (X): "),
                                      ("value_Y", "third value (Y): ")):
            if getattr(self, attrname) is None:
                continue
            value = process_expression(getattr(self, attrname), scope, self.sourceref)
            setattr(self, attrname, value)
            if isinstance(value, (int, float, str, bool)):
                try:
                    _, value = coerce_constant_value(DataType.BYTE, value, self.sourceref)
                except (OverflowError, TypeError) as x:
                    raise ParseError(description + str(x), self.sourceref) from None
                setattr(self, attrname, value)


@attr.s(cmp=False, repr=False)
class TargetRegisters(AstNode):
    # This is a tuple of 1 or more registers.
    # In its multiple-register form it is only used to be able to parse
    # the result of a subroutine call such as A,X = sub().
    # It will be replaced by a regular Register node if it contains just one register.
    registers = attr.ib(type=list)

    def add(self, register: str) -> None:
        self.registers.append(register)


@attr.s(cmp=False, repr=False)
class InlineAssembly(AstNode):
    assembly = attr.ib(type=str)


@attr.s(cmp=False, repr=True, slots=True)
class VarDef(AstNode):
    """
    Variable definition. vartype "const"/"var"/"memory" is converted to VarType, and a
    DatatypeNode is converted to a DataType plus a size (list of dimensions).
    """
    name = attr.ib(type=str)
    vartype = attr.ib()
    datatype = attr.ib()
    value = attr.ib(default=None)
    size = attr.ib(type=list, default=None)
    zp_address = attr.ib(type=int, default=None, init=False)  # the address in the zero page if this var is there, will be set later

    def __attrs_post_init__(self):
        # convert vartype to enum
        if self.vartype == "const":
            self.vartype = VarType.CONST
        elif self.vartype == "var":
            self.vartype = VarType.VAR
        elif self.vartype == "memory":
            self.vartype = VarType.MEMORY
        else:
            raise ValueError("invalid vartype", self.vartype)
        # convert datatype node to enum + size
        if self.datatype is None:
            assert self.size is None
            self.size = [1]
            self.datatype = DataType.BYTE
        elif isinstance(self.datatype, DatatypeNode):
            assert self.size is None
            self.size = self.datatype.dimensions or [1]
            self.datatype = self.datatype.to_enum()
        # guard against a missing size (sum(None) would raise a TypeError)
        if self.datatype.isarray() and self.size is not None and sum(self.size) in (0, 1):
            print("warning: {}: array/matrix with size 1, use normal byte/word instead for efficiency".format(self.sourceref))
        if self.vartype == VarType.CONST and self.value is None:
            raise ParseError("constant value assignment is missing",
                             attr.evolve(self.sourceref, column=self.sourceref.column+len(self.name)))
        # if the value is an expression, mark it as a *constant* expression here
        if isinstance(self.value, AstNode):
            self.value.processed_expr_must_be_constant = True
        elif self.value is None and (self.datatype.isnumeric() or self.datatype.isarray()):
            self.value = 0
        # if it's a matrix with interleave, it must be memory mapped
        if self.datatype == DataType.MATRIX and len(self.size) == 3:
            if self.vartype != VarType.MEMORY:
                raise ParseError("matrix with interleave can only be a memory-mapped variable", self.sourceref)
        # note: value coercion is done later, when all expressions are evaluated

    def process_expressions(self, scope: Scope) -> None:
        self.value = process_expression(self.value, scope, self.sourceref)
        assert not isinstance(self.value, Expression), "processed expression for vardef should reduce to a constant value"
        if self.vartype in (VarType.CONST, VarType.VAR):
            try:
                _, self.value = coerce_constant_value(self.datatype, self.value, self.sourceref)
            except OverflowError as x:
                raise ParseError(str(x), self.sourceref) from None
            except TypeError as x:
                raise ParseError("processed expression for vardef is not a constant value: " + str(x), self.sourceref) from None


@attr.s(cmp=False, slots=True, repr=False)
class DatatypeNode(AstNode):
    """A datatype name as written in the source, with optional dimensions."""
    name = attr.ib(type=str)
    dimensions = attr.ib(type=list, default=None, validator=dimensions_validator)  # if set, 1 or more dimensions (ints)

    def to_enum(self):
        """Map the source datatype name to its DataType (KeyError for unknown names)."""
        return {
            "byte": DataType.BYTE,
            "word": DataType.WORD,
            "float": DataType.FLOAT,
            "text": DataType.STRING,
            "ptext": DataType.STRING_P,
            "stext": DataType.STRING_S,
            "pstext": DataType.STRING_PS,
            "matrix": DataType.MATRIX,
            "array": DataType.BYTEARRAY,
            "wordarray": DataType.WORDARRAY
        }[self.name]


@attr.s(cmp=False, repr=False)
class Subroutine(AstNode):
    """Subroutine definition: either with a body (scope) or at a fixed address, not both."""
    name = attr.ib(type=str)
    param_spec = attr.ib(type=list)
    result_spec = attr.ib(type=list)
    scope = attr.ib(type=Scope, default=None)
    address = attr.ib(type=int, default=None, validator=validate_address)

    @property
    def nodes(self) -> Iterable[AstNode]:
        if self.scope:
            return self.scope.nodes
        return []

    def __attrs_post_init__(self):
        if self.scope and self.address is not None:
            raise ValueError("subroutine must have either a scope or an address, not both")
        if self.scope:
            self.scope.name = self.name


@attr.s(cmp=False, repr=False)
class Goto(AstNode):
    """Jump to a call target, optionally conditional."""
    target = attr.ib()
    if_stmt = attr.ib(default=None)
    condition = attr.ib(default=None)

    def process_expressions(self, scope: Scope) -> None:
        if self.condition is not None:
            # the condition may already be a plain value (e.g. folded earlier); fall back to our own sourceref then
            condition_sref = self.condition.sourceref if isinstance(self.condition, AstNode) else self.sourceref
            self.condition = process_expression(self.condition, scope, condition_sref)

    def verify_symbol_names(self, scope: Scope) -> None:
        if isinstance(self.target.target, SymbolName):
            check_symbol_definition(self.target.target.name, scope, self.target.target.sourceref)


@attr.s(cmp=False, repr=False)
class Dereference(AstNode):
    """Memory dereference of a location; the datatype must be numeric (default byte)."""
    location = attr.ib()
    datatype = attr.ib()
    size = attr.ib(type=int, default=None)

    def __attrs_post_init__(self):
        # convert datatype node to enum + size
        if self.datatype is None:
            assert self.size is None
            self.size = 1
            self.datatype = DataType.BYTE
        elif isinstance(self.datatype, DatatypeNode):
            assert self.size is None
            self.size = self.datatype.dimensions
            if not self.datatype.to_enum().isnumeric():
                raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref)
            self.datatype = self.datatype.to_enum()

    def verify_symbol_names(self, scope: Scope) -> None:
        # NOTE(review): this may never be invoked. Assignment.verify_symbol_names checks dereference
        # locations itself (a leftover debug print used to be here to find out).
        if isinstance(self.location, SymbolName):
            check_symbol_definition(self.location.name, scope, self.location.sourceref)


@attr.s(cmp=False, repr=False)
class LiteralValue(AstNode):
    value = attr.ib()


@attr.s(cmp=False, repr=False)
class AddressOf(AstNode):
    name = attr.ib(type=str)


@attr.s(cmp=False, repr=False)
class IncrDecr(AstNode):
    # increment or decrement something by a CONSTANT value (1 or more)
    target = attr.ib()
    operator = attr.ib(type=str, validator=attr.validators.in_(["++", "--"]))
    howmuch = attr.ib(default=1)

    def __attrs_post_init__(self):
        # make sure the amount is always >= 0
        if self.howmuch < 0:
            self.howmuch = -self.howmuch
            self.operator = "++" if self.operator == "--" else "--"
        if isinstance(self.target, Register):
            if self.target.name not in REGISTER_BYTES | REGISTER_WORDS:
                raise ParseError("cannot incr/decr that register", self.sourceref)
        if isinstance(self.target, TargetRegisters):
            raise ParseError("cannot incr/decr multiple registers at once", self.sourceref)

    def verify_symbol_names(self, scope: Scope) -> None:
        if isinstance(self.target, SymbolName):
            check_symbol_definition(self.target.name, scope, self.target.sourceref)


@attr.s(cmp=False, repr=False)
class SymbolName(AstNode):
    name = attr.ib(type=str)


@attr.s(cmp=False, slots=True, repr=False)
class CallTarget(AstNode):
    target = attr.ib()
    address_of = attr.ib(type=bool)


@attr.s(cmp=False, slots=True, repr=False)
class CallArgument(AstNode):
    value = attr.ib()
    name = attr.ib(type=str, default=None)

    def process_expressions(self, scope: Scope) -> None:
        self.value = process_expression(self.value, scope, self.sourceref)


@attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode):
    """Unary (operator + left) or binary (left operator right) expression."""
    left = attr.ib()
    operator = attr.ib(type=str)
    right = attr.ib()
    unary = attr.ib(type=bool, default=False)

    def __attrs_post_init__(self):
        assert self.operator not in ("++", "--"), "incr/decr should not be an expression"

    def process_expressions(self, scope: Scope) -> None:
        raise RuntimeError("must be done via parent node's process_expressions")

    def evaluate_primitive_constants(self, scope: Scope) -> Union[int, float, str, bool]:
        """
        Evaluate a binary expression whose operands are already primitive constants.
        Raises TypeError/ValueError for non-primitive operands or an unsupported operator, and
        ExpressionEvaluationError when the evaluation itself fails.
        """
        # make sure the lvalue and rvalue are primitives, and the operator is allowed
        if not isinstance(self.left, (LiteralValue, int, float, str, bool)):
            raise TypeError("left", self)
        if not isinstance(self.right, (LiteralValue, int, float, str, bool)):
            raise TypeError("right", self)
        if self.operator not in {'+', '-', '*', '/', '//', '~', '<', '>', '<=', '>=', '==', '!='}:
            raise ValueError("operator", self)
        estr = "{} {} {}".format(repr(self.left), self.operator, repr(self.right))
        try:
            return eval(estr, {}, {})  # safe because of checks above
        except Exception as x:
            raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None

    def print_tree(self) -> None:
        def tree(expr: Any, level: int) -> str:
            indent = " "*level
            if not isinstance(expr, Expression):
                return indent + str(expr) + "\n"
            if expr.unary:
                return indent + "{}{}".format(expr.operator, tree(expr.left, level+1))
            else:
                return indent + "{}".format(tree(expr.left, level+1)) + \
                       indent + str(expr.operator) + "\n" + \
                       indent + "{}".format(tree(expr.right, level + 1))
        print(tree(self, 0))


def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType:
    """Try to determine the DataType of an assignment target node; raises TypeError if it can't."""
    if isinstance(assignmenttarget, (VarDef, Dereference, Register)):
        return assignmenttarget.datatype
    elif isinstance(assignmenttarget, SymbolName):
        symdef = scope.lookup(assignmenttarget.name)
        if isinstance(symdef, VarDef):
            return symdef.datatype
    elif isinstance(assignmenttarget, TargetRegisters):
        if len(assignmenttarget.registers) == 1:
            return datatype_of(assignmenttarget.registers[0], scope)
    raise TypeError("cannot determine datatype", assignmenttarget)


def coerce_constant_value(datatype: DataType, value: Any, sourceref: SourceRef=None) -> Tuple[bool, Any]:
    """
    Coerce a value to the given datatype.
    Returns (True, new_value) when the value was converted: a single character becomes its byte
    value, and a float becomes a truncated int for byte/word/matrix. Otherwise returns
    (False, value). Raises OverflowError when a number is out of range for the datatype, and
    TypeError when the value can't be assigned to it.
    """
    def verify_bounds(value: Union[int, float, str]) -> None:
        # if the value is out of bounds, raise an overflow exception
        if isinstance(value, (int, float)):
            if datatype == DataType.BYTE and not (0 <= value <= 0xff):  # type: ignore
                raise OverflowError("value out of range for byte: " + str(value))
            if datatype == DataType.WORD and not (0 <= value <= 0xffff):  # type: ignore
                raise OverflowError("value out of range for word: " + str(value))
            if datatype == DataType.FLOAT and not (FLOAT_MAX_NEGATIVE <= value <= FLOAT_MAX_POSITIVE):  # type: ignore
                raise OverflowError("value out of range for float: " + str(value))

    # a single character is converted to its numeric byte value
    if isinstance(value, str) and len(value) == 1 and (datatype.isnumeric() or datatype.isarray()):
        return True, char_to_bytevalue(value)
    # if we're an integer value and the passed value is float, truncate it (and give a warning)
    if datatype in (DataType.BYTE, DataType.WORD, DataType.MATRIX) and isinstance(value, float):
        frac, _ = math.modf(value)  # modf returns a (fractional, integral) tuple
        if frac != 0:
            print_warning("float value truncated ({} to datatype {})".format(value, datatype.name), sourceref=sourceref)
        value = int(value)
        verify_bounds(value)
        return True, value
    if isinstance(value, (int, float)):
        verify_bounds(value)
    if isinstance(value, (Expression, SubCall)):
        return False, value
    elif datatype == DataType.WORD:
        if not isinstance(value, (int, float, str, Dereference, Register, SymbolName, AddressOf)):
            raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref)
    elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT):
        if not isinstance(value, (int, float, Dereference, Register, SymbolName)):
            raise TypeError("cannot assign '{:s}' to {:s}".format(type(value).__name__, datatype.name.lower()), sourceref)
    return False, value


def process_expression(value: Any, scope: Scope, sourceref: SourceRef) -> Any:
    """Process/simplify an expression (constant folding etc.); constant-marked nodes must fold to a constant."""
    if isinstance(value, AstNode):
        must_be_constant = value.processed_expr_must_be_constant
    else:
        must_be_constant = False
    if must_be_constant:
        return process_constant_expression(value, sourceref, scope)
    else:
        return process_dynamic_expression(value, sourceref, scope)


def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> Union[int, float, str, bool]:
    """
    Evaluate expr to a single constant value (int, float, str, bool).
    Raises ExpressionEvaluationError when that's not possible. Sub-expressions of an
    Expression node are replaced in place by their evaluated values.
    """
    if expr is None or isinstance(expr, (int, float, str, bool)):
        return expr
    elif isinstance(expr, LiteralValue):
        return expr.value
    elif isinstance(expr, SymbolName):
        value = check_symbol_definition(expr.name, symbolscope, expr.sourceref)
        if isinstance(value, VarDef):
            if value.vartype == VarType.MEMORY:
                raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
            # NOTE(review): for a VarType.VAR symbol this yields its *initial* value, so dynamic
            # expressions referencing a variable get folded to that initial value. Confirm this is intended.
            value = value.value
        if isinstance(value, Expression):
            raise ExpressionEvaluationError("circular reference?", expr.sourceref)
        elif isinstance(value, (int, float, str, bool)):
            return value
        else:
            raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
    elif isinstance(expr, AddressOf):
        assert isinstance(expr.name, SymbolName)
        value = check_symbol_definition(expr.name.name, symbolscope, expr.sourceref)
        if isinstance(value, VarDef):
            if value.vartype == VarType.MEMORY:
                return value.value
            if value.vartype == VarType.CONST:
                raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref)
            raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
                                            .format(value.__class__.__name__), expr.name.sourceref)
        else:
            raise ExpressionEvaluationError("constant address required, not {}"
                                            .format(value.__class__.__name__), expr.name.sourceref)
    elif isinstance(expr, SubCall):
        if isinstance(expr.target, CallTarget):
            target = expr.target.target
            if isinstance(target, SymbolName):  # 'function(1,2,3)'
                funcname = target.name
                if funcname in math_functions or funcname in builtin_functions:
                    # evaluate the arguments first; they must be constants too
                    func_args = []
                    for a in (process_constant_expression(callarg.value, sourceref, symbolscope) for callarg in expr.arguments):
                        if isinstance(a, LiteralValue):
                            func_args.append(a.value)
                        else:
                            func_args.append(a)
                    func = math_functions.get(funcname, builtin_functions.get(funcname))
                    try:
                        return func(*func_args)
                    except Exception as x:
                        raise ExpressionEvaluationError(str(x), expr.sourceref)
                else:
                    raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
            elif isinstance(target, Dereference):  # '[...](1,2,3)'
                raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref)
            elif type(target) is int:  # '64738()'
                raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
            else:
                raise NotImplementedError("weird call target", target)
        else:
            raise ParseError("function name required, not {}".format(expr.target.__class__.__name__), expr.sourceref)
    elif not isinstance(expr, Expression):
        raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
    if expr.unary:
        left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
        expr.left = process_constant_expression(expr.left, left_sourceref, symbolscope)
        if isinstance(expr.left, (int, float)):
            try:
                if expr.operator == '-':
                    return -expr.left
                elif expr.operator == '~':
                    return ~expr.left  # type: ignore
                elif expr.operator in ("++", "--"):
                    raise ValueError("incr/decr should not be an expression")
                raise ValueError("invalid unary operator", expr.operator)
            except TypeError as x:
                raise ParseError(str(x), expr.sourceref) from None
        raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
    else:
        left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
        expr.left = process_constant_expression(expr.left, left_sourceref, symbolscope)
        right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
        expr.right = process_constant_expression(expr.right, right_sourceref, symbolscope)
        if isinstance(expr.left, (LiteralValue, int, float, str, bool)):
            if isinstance(expr.right, (LiteralValue, int, float, str, bool)):
                return expr.evaluate_primitive_constants(symbolscope)
            else:
                raise ExpressionEvaluationError("constant value required on right, not {}"
                                                .format(expr.right.__class__.__name__), right_sourceref)
        else:
            raise ExpressionEvaluationError("constant value required on left, not {}"
                                            .format(expr.left.__class__.__name__), left_sourceref)


def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any:
    """Look up a symbol; an undefined symbol is reported as a ParseError at sref."""
    try:
        return scope.lookup(name)
    except UndefinedSymbolError as x:
        raise ParseError(str(x), sref)


def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> Any:
    """
    Constant-fold as much of a (possibly non-constant) expression as possible.
    Returns a constant value when folding succeeds, otherwise the (partially folded) node.
    Referenced symbols must exist; otherwise a ParseError is raised.
    """
    if expr is None or isinstance(expr, (int, float, str, bool)):
        return expr
    elif isinstance(expr, LiteralValue):
        return expr.value
    elif isinstance(expr, SymbolName):
        try:
            return process_constant_expression(expr, sourceref, symbolscope)
        except ExpressionEvaluationError:
            return expr
    elif isinstance(expr, AddressOf):
        try:
            return process_constant_expression(expr, sourceref, symbolscope)
        except ExpressionEvaluationError:
            return expr
    elif isinstance(expr, SubCall):
        try:
            return process_constant_expression(expr, sourceref, symbolscope)
        except ExpressionEvaluationError:
            if isinstance(expr.target.target, SymbolName):
                check_symbol_definition(expr.target.target.name, symbolscope, expr.target.target.sourceref)
            return expr
    elif isinstance(expr, Register):
        return expr
    elif isinstance(expr, Dereference):
        if isinstance(expr.location, SymbolName):
            check_symbol_definition(expr.location.name, symbolscope, expr.location.sourceref)
        return expr
    elif not isinstance(expr, Expression):
        raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
    if expr.unary:
        left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
        expr.left = process_dynamic_expression(expr.left, left_sourceref, symbolscope)
        try:
            return process_constant_expression(expr, sourceref, symbolscope)
        except ExpressionEvaluationError:
            return expr
    else:
        left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
        expr.left = process_dynamic_expression(expr.left, left_sourceref, symbolscope)
        right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
        expr.right = process_dynamic_expression(expr.right, right_sourceref, symbolscope)
        try:
            return process_constant_expression(expr, sourceref, symbolscope)
        except ExpressionEvaluationError:
            return expr


# ----------------- PLY parser definition follows ----------------------

def p_start(p):
    """
    start :  empty
          |  module_elements
    """
    # Build the Module root node with the global scope; an empty file yields an empty module.
    # The filename is always taken from the lexer driving this parse (p.lexer), also for empty input.
    filename = p.lexer.source_filename
    scope = Scope(nodes=p[1] or [], sourceref=_token_sref(p, 1))
    scope.name = "<" + filename + " global scope>"
    if p[1]:
        p[0] = Module(name=filename, scope=scope, sourceref=_token_sref(p, 1))
    else:
        p[0] = Module(name=filename, scope=scope, sourceref=SourceRef(filename, 1, 1))


def p_module(p):
    """
    module_elements :  module_elt
                    |  module_elements  module_elt
    """
    # collect the module elements in a list, skipping None (empty lines)
    if len(p) == 2:
        if p[1] is None:
            p[0] = []
        else:
            p[0] = [p[1]]
    else:
        if p[2] is None:
            p[0] = p[1]
        else:
            p[0] = p[1] + [p[2]]


def p_module_elt(p):
    """
    module_elt :  ENDL
               |  directive
               |  block
    """
    # an ENDL produces None, which p_module skips
    if p[1] != '\n':
        p[0] = p[1]


def p_directive(p):
    """
    directive :  DIRECTIVE ENDL
              |  DIRECTIVE directive_args ENDL
    """
    if len(p) == 3:
        p[0] = Directive(name=p[1], sourceref=_token_sref(p, 1))
    else:
        p[0] = Directive(name=p[1], args=p[2], sourceref=_token_sref(p, 1))


def p_directive_args(p):
    """
    directive_args :  directive_arg
                   |  directive_args ',' directive_arg
    """
    # comma separated directive arguments, as a list
    if len(p) == 2:
        p[0] = [p[1]]
    else:
        p[0] = p[1] + [p[3]]


def p_directive_arg(p):
    """
    directive_arg :  NAME
                  |  INTEGER
                  |  STRING
                  |  BOOLEAN
    """
    p[0] = p[1]
def p_block_name_addr(p):
    """
    block :  BITINVERT NAME INTEGER endl_opt scope
    """
    # named block placed at an explicit address
    p[0] = Block(name=p[2], address=p[3], scope=p[5], sourceref=_token_sref(p, 2))


def p_block_name(p):
    """
    block :  BITINVERT NAME endl_opt scope
    """
    p[0] = Block(name=p[2], scope=p[4], sourceref=_token_sref(p, 2))


def p_block(p):
    """
    block :  BITINVERT endl_opt scope
    """
    # anonymous block
    p[0] = Block(scope=p[3], sourceref=_token_sref(p, 1))


def p_endl_opt(p):
    """
    endl_opt :  empty
             |  ENDL
    """
    pass


def p_scope(p):
    """
    scope :  '{' scope_elements_opt '}'
    """
    p[0] = Scope(nodes=p[2] or [], sourceref=_token_sref(p, 1))


def p_scope_elements_opt(p):
    """
    scope_elements_opt :  empty
                       |  scope_elements
    """
    p[0] = p[1]


# Builds the list of scope elements, leaving out empty lines (None or '\n').
def p_scope_elements(p):
    """
    scope_elements :  scope_element
                   |  scope_elements scope_element
    """
    if len(p) == 2:
        p[0] = [] if p[1] in (None, '\n') else [p[1]]
    else:
        if p[2] in (None, '\n'):
            p[0] = p[1]
        else:
            p[0] = p[1] + [p[2]]


def p_scope_element(p):
    """
    scope_element :  ENDL
                  |  label
                  |  directive
                  |  vardef
                  |  subroutine
                  |  inlineasm
                  |  statement
    """
    # an empty line (ENDL) produces None, everything else is passed through
    if p[1] != '\n':
        p[0] = p[1]
    else:
        p[0] = None


def p_label(p):
    """
    label :  LABEL
    """
    p[0] = Label(name=p[1], sourceref=_token_sref(p, 1))


def p_inlineasm(p):
    """
    inlineasm :  INLINEASM ENDL
    """
    p[0] = InlineAssembly(assembly=p[1], sourceref=_token_sref(p, 1))


# Variable definition without an initial value.
def p_vardef(p):
    """
    vardef :  VARTYPE type_opt NAME ENDL
    """
    p[0] = VarDef(name=p[3], vartype=p[1], datatype=p[2], sourceref=_token_sref(p, 3))


# Variable definition with a value expression.
def p_vardef_value(p):
    """
    vardef :  VARTYPE type_opt NAME IS expression
    """
    p[0] = VarDef(name=p[3], vartype=p[1], datatype=p[2], value=p[5], sourceref=_token_sref(p, 3))


# Optional datatype, possibly with dimensions; 'empty' leaves p[0] as None.
def p_type_opt(p):
    """
    type_opt :  DATATYPE '(' dimensions ')'
             |  DATATYPE
             |  empty
    """
    if len(p) == 5:
        p[0] = DatatypeNode(name=p[1], dimensions=p[3], sourceref=_token_sref(p, 1))
    elif len(p) == 2 and p[1]:
        p[0] = DatatypeNode(name=p[1], sourceref=_token_sref(p, 1))


def p_dimensions(p):
    """
    dimensions :  INTEGER
               |  dimensions ',' INTEGER
    """
    if len(p) == 2:
        p[0] = [p[1]]
    else:
        p[0] = p[1] + [p[3]]


def p_literal_value(p):
    """
    literal_value :  INTEGER
                  |  FLOATINGPOINT
                  |  STRING
                  |  CHARACTER
                  |  BOOLEAN
    """
    value = p[1]
    token_type = p.slice[-1].type
    if token_type == "CHARACTER":
        value = char_to_bytevalue(value)  # character literals are converted to byte value.
    elif token_type == "BOOLEAN":
        value = int(value)  # boolean literals are converted to integer form (true=1, false=0).
    p[0] = LiteralValue(value=value, sourceref=_token_sref(p, 1))


# A subroutine either has a body (a Scope) or is an external routine at a fixed address (int).
def p_subroutine(p):
    """
    subroutine :  SUB NAME '(' sub_param_spec ')' RARROW '(' sub_result_spec ')' subroutine_body ENDL
    """
    body = p[10]
    if isinstance(body, Scope):
        p[0] = Subroutine(name=p[2], param_spec=p[4] or [], result_spec=p[8] or [],
                          scope=body, sourceref=_token_sref(p, 1))
    elif type(body) is int:
        p[0] = Subroutine(name=p[2], param_spec=p[4] or [], result_spec=p[8] or [],
                          address=body, sourceref=_token_sref(p, 1))
    else:
        raise TypeError("subroutine_body", p.slice)


def p_sub_param_spec(p):
    """
    sub_param_spec :  empty
                   |  sub_param_list
    """
    p[0] = p[1]


def p_sub_param_list(p):
    """
    sub_param_list :  sub_param
                   |  sub_param_list ',' sub_param
    """
    if len(p) == 2:
        p[0] = [p[1]]
    else:
        p[0] = p[1] + [p[3]]


# Each parameter becomes a (label-or-None, register) tuple.
def p_sub_param(p):
    """
    sub_param :  LABEL REGISTER
              |  REGISTER
    """
    if len(p) == 3:
        p[0] = (p[1], p[2])
    elif len(p) == 2:
        p[0] = (None, p[1])


def p_sub_result_spec(p):
    """
    sub_result_spec :  empty
                    |  '?'
                    |  sub_result_list
    """
    if p[1] == '?':
        p[0] = ['A', 'X', 'Y']  # '?' means: all registers clobbered
    else:
        p[0] = p[1]


def p_sub_result_list(p):
    """
    sub_result_list :  sub_result_reg
                    |  sub_result_list ',' sub_result_reg
    """
    if len(p) == 2:
        p[0] = [p[1]]
    else:
        p[0] = p[1] + [p[3]]


def p_sub_result_reg(p):
    """
    sub_result_reg :  REGISTER
                   |  CLOBBEREDREGISTER
    """
    p[0] = p[1]


# Either a Scope body, or (after IS) the integer address of an external routine.
def p_subroutine_body(p):
    """
    subroutine_body :  scope
                    |  IS INTEGER
    """
    if len(p) == 2:
        p[0] = p[1]
    else:
        p[0] = p[2]


def p_statement(p):
    """
    statement :  assignment ENDL
              |  aug_assignment ENDL
              |  subroutine_call ENDL
              |  goto ENDL
              |  conditional_goto ENDL
              |  incrdecr ENDL
              |  return ENDL
    """
    p[0] = p[1]


def p_incrdecr(p):
    """
    incrdecr :  assignment_target INCR
             |  assignment_target DECR
    """
    p[0] = IncrDecr(target=p[1], operator=p[2], sourceref=_token_sref(p, 2))


def p_call_subroutine(p):
    """
    subroutine_call :  calltarget preserveregs_opt '(' call_arguments_opt ')'
    """
    p[0] = SubCall(target=p[1], preserve_regs=p[2], arguments=p[4], sourceref=_token_sref(p, 3))


def p_preserveregs_opt(p):
    """
    preserveregs_opt :  empty
                     |  preserveregs
    """
    p[0] = p[1]


def p_preserveregs(p):
    """
    preserveregs :  PRESERVEREGS
    """
    p[0] = PreserveRegs(registers=p[1], sourceref=_token_sref(p, 1))


def p_call_arguments_opt(p):
    """
    call_arguments_opt :  empty
                       |  call_arguments
    """
    p[0] = p[1]


def p_call_arguments(p):
    """
    call_arguments :  call_argument
                   |  call_arguments ',' call_argument
    """
    if len(p) == 2:
        p[0] = [p[1]]
    else:
        p[0] = p[1] + [p[3]]


def p_call_argument(p):
    """
    call_argument :  expression
                  |  register IS expression
                  |  NAME IS expression
    """
    if len(p) == 2:
        p[0] = CallArgument(value=p[1], sourceref=_token_sref(p, 1))
    elif len(p) == 4:
        # NOTE(review): for the 'register IS expression' alternative, name is a Register
        # node (see p_register) rather than the NAME string; confirm consumers expect that.
        p[0] = CallArgument(name=p[1], value=p[3], sourceref=_token_sref(p, 1))


# Return with up to three values, which go into value_A / value_X / value_Y.
def p_return(p):
    """
    return :  RETURN
           |  RETURN expression
           |  RETURN expression ',' expression
           |  RETURN expression ',' expression ',' expression
    """
    if len(p) == 2:
        p[0] = Return(sourceref=_token_sref(p, 1))
    elif len(p) == 3:
        p[0] = Return(value_A=p[2], sourceref=_token_sref(p, 1))
    elif len(p) == 5:
        p[0] = Return(value_A=p[2], value_X=p[4], sourceref=_token_sref(p, 1))
    elif len(p) == 7:
        p[0] = Return(value_A=p[2], value_X=p[4], value_Y=p[6], sourceref=_token_sref(p, 1))


def p_register(p):
    """
    register :  REGISTER
    """
    p[0] = Register(name=p[1], sourceref=_token_sref(p, 1))


def p_goto(p):
    """
    goto :  GOTO calltarget
    """
    p[0] = Goto(target=p[2], sourceref=_token_sref(p, 1))


def p_conditional_goto_plain(p):
    """
    conditional_goto :  IF GOTO calltarget
    """
    p[0] = Goto(target=p[3], if_stmt=p[1], sourceref=_token_sref(p, 1))


def p_conditional_goto_expr(p):
    """
    conditional_goto :  IF expression GOTO calltarget
    """
    p[0] = Goto(target=p[4], if_stmt=p[1], condition=p[2], sourceref=_token_sref(p, 1))


# Every alternative is a single symbol, so the target is never an address-of here.
def p_calltarget(p):
    """
    calltarget :  symbolname
               |  INTEGER
               |  dereference
    """
    p[0] = CallTarget(target=p[1], address_of=False, sourceref=_token_sref(p, 1))


def p_dereference(p):
    """
    dereference :  '[' dereference_operand ']'
    """
    # p[2] is the (location, datatype) tuple built by p_dereference_operand
    p[0] = Dereference(location=p[2][0], datatype=p[2][1], sourceref=_token_sref(p, 1))


def p_dereference_operand(p):
    """
    dereference_operand :  symbolname type_opt
                        |  REGISTER type_opt
                        |  INTEGER type_opt
    """
    p[0] = (p[1], p[2])


def p_symbolname(p):
    """
    symbolname :  NAME
               |  DOTTEDNAME
    """
    p[0] = SymbolName(name=p[1], sourceref=_token_sref(p, 1))


# For chained assignments (a = b = 5) the right-hand side is itself an Assignment node.
def p_assignment(p):
    """
    assignment :  assignment_target IS expression
               |  assignment_target IS assignment
    """
    p[0] = Assignment(left=[p[1]], right=p[3], sourceref=_token_sref(p, 2))


def p_aug_assignment(p):
    """
    aug_assignment :  assignment_target AUGASSIGN expression
    """
    p[0] = AugAssignment(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2))


# Operator precedence for PLY, listed from LOWEST to HIGHEST binding strength.
# Comparisons bind weakest, so "a + 1 < b" parses as "(a + 1) < b" and "-a < b" as "(-a) < b",
# matching the usual (C/Python) operator precedence rules.
precedence = (
    ('left', "LT", "GT", "LE", "GE", "EQUALS", "NOTEQUALS"),
    ('left', '+', '-'),
    ('left', '*', '/', 'INTEGERDIVIDE'),
    ('right', 'UNARY_MINUS', 'BITINVERT', "UNARY_ADDRESSOF"),
    ('nonassoc', "COMMENT"),
)


def p_expression(p):
    """
    expression :  expression '+' expression
               |  expression '-' expression
               |  expression '*' expression
               |  expression '/' expression
               |  expression INTEGERDIVIDE expression
               |  expression LT expression
               |  expression GT expression
               |  expression LE expression
               |  expression GE expression
               |  expression EQUALS expression
               |  expression NOTEQUALS expression
    """
    p[0] = Expression(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2))


def p_expression_uminus(p):
    """
    expression :  '-' expression %prec UNARY_MINUS
    """
    p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))


def p_expression_addressof(p):
    """
    expression :  BITAND symbolname %prec UNARY_ADDRESSOF
    """
    p[0] = AddressOf(name=p[2], sourceref=_token_sref(p, 1))


def p_unary_expression_bitinvert(p):
    """
    expression :  BITINVERT expression
    """
    p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))


def p_expression_group(p):
    """
    expression :  '(' expression ')'
    """
    p[0] = p[2]


def p_expression_expr_value(p):
    """
    expression :  expression_value
    """
    p[0] = p[1]


def p_expression_value(p):
    """
    expression_value :  literal_value
                     |  symbolname
                     |  register
                     |  subroutine_call
                     |  dereference
    """
    p[0] = p[1]


def p_assignment_target(p):
    """
    assignment_target :  target_registers
                      |  symbolname
                      |  dereference
    """
    target = p[1]
    if isinstance(target, TargetRegisters) and len(target.registers) == 1:
        # a single target register is used directly instead of a one-element TargetRegisters
        assert isinstance(target.registers[0], Register)
        target = target.registers[0]
    p[0] = target


def p_target_registers(p):
    """
    target_registers :  register
                     |  target_registers ',' register
    """
    if len(p) == 2:
        p[0] = TargetRegisters(registers=[p[1]], sourceref=_token_sref(p, 1))
    else:
        p[1].add(p[3])
        p[0] = p[1]


def p_empty(p):
    """
    empty :
    """
    pass


# PLY syntax error hook: p is the offending token, or None when the input ended unexpectedly.
# Errors are reported through lexer.error_function(sourceref, fmtstring, *args).
def p_error(p):
    stack_state_str = ' '.join([symbol.type for symbol in parser.symstack][1:])
    print('\n[ERROR DEBUG: parser state={:d} stack: {} . {} ]'.format(parser.state, stack_state_str, p))
    if p:
        sref = SourceRef(p.lexer.source_filename, p.lineno, find_tok_column(p))
        if p.value in ("", "\n"):
            p.lexer.error_function(sref, "syntax error before end of line")
        else:
            p.lexer.error_function(sref, "syntax error before or at '{:.20s}'", str(p.value).rstrip())
    else:
        # no token to point at: report the last line the lexer reached in this file
        sref = SourceRef(lexer.source_filename, lexer.lineno, 1)
        lexer.error_function(sref, "syntax error at end of input")


def _token_sref(p, token_idx):
    """
    Returns the coordinates for the YaccProduction object 'p' indexed with 'token_idx'.
    The coordinate includes the 'lineno' and 'column', starting from 1.
    """
    lexdata = p.lexer.lexdata
    token_pos = p.lexpos(token_idx)
    # index of the first character on the token's line; rfind gives -1 on the first line -> 0
    line_start = lexdata.rfind('\n', 0, token_pos) + 1
    # tabs are expanded (default tab size) before counting; +1 makes the column 1-based
    column = len(lexdata[line_start:token_pos].expandtabs()) + 1
    return SourceRef(p.lexer.source_filename, p.lineno(token_idx), column)


class TokenFilter:
    """
    Token source for the parser: wraps a lexer and collapses every run of
    consecutive ENDL tokens into a single ENDL token.
    """

    def __init__(self, lexer):
        self.lexer = lexer
        self.prev_was_EOL = False
        assert "ENDL" in tokens

    def token(self):
        # make sure we only ever emit ONE "ENDL" token in sequence
        if self.prev_was_EOL:
            # skip all EOLS that might follow
            while True:
                tok = self.lexer.token()
                if not tok or tok.type != "ENDL":
                    break
            self.prev_was_EOL = False
        else:
            tok = self.lexer.token()
            self.prev_was_EOL = tok and tok.type == "ENDL"
        return tok


parser = yacc(write_tables=True)


def _print_parse_error(sourceref: SourceRef, fmtstring: str, *args: Any) -> None:
    # fallback error reporter for parse_file() when no error function is supplied
    print("{} {}".format(sourceref, fmtstring.format(*args)))


def parse_file(filename: str, lexer_error_func=None) -> Module:
    """
    Parse the IL65 source file 'filename' and return the resulting parse tree.

    'lexer_error_func' is installed as lexer.error_function and is called by p_error as
    func(sourceref, fmtstring, *args) for each syntax error (and presumably by the lexer
    for its own errors - confirm in plylex). When omitted, errors are printed to stdout
    instead of crashing on a call to None.
    """
    lexer.error_function = lexer_error_func or _print_parse_error
    lexer.lineno = 1
    lexer.source_filename = filename
    tfilter = TokenFilter(lexer)
    # text mode already uses universal newlines; the 'U' flag was removed in Python 3.11
    with open(filename, "r") as inf:
        sourcecode = inf.read()
    # pass the lexer explicitly so p.lexer inside the rules is this configured lexer
    return parser.parse(input=sourcecode, lexer=lexer, tokenfunc=tfilter.token)