From 3b0d6e969b35d35753afef78a6176e00543b4c3a Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sun, 31 Dec 2017 04:45:27 +0100 Subject: [PATCH] ast refactor --- il65/astdefs.py | 500 +++++++++++++++++++ il65/codegen.py | 209 ++++---- il65/{astparse.py => exprparse.py} | 0 il65/parse.py | 765 ++++++----------------------- il65/preprocess.py | 9 +- 5 files changed, 760 insertions(+), 723 deletions(-) create mode 100644 il65/astdefs.py rename il65/{astparse.py => exprparse.py} (100%) diff --git a/il65/astdefs.py b/il65/astdefs.py new file mode 100644 index 000000000..f99dbf9b3 --- /dev/null +++ b/il65/astdefs.py @@ -0,0 +1,500 @@ +""" +Programming Language for 6502/6510 microprocessors +These are the Abstract Syntax Tree node classes that form the Parse Tree. + +Written by Irmen de Jong (irmen@razorvine.net) +License: GNU GPL 3.0, see LICENSE +""" + +from .symbols import SourceRef, SymbolTable, SubroutineDef, SymbolDefinition, SymbolError, DataType, \ + STRING_DATATYPES, REGISTER_SYMBOLS, REGISTER_BYTES, REGISTER_SBITS, check_value_in_range +from typing import Dict, Set, List, Tuple, Optional, Union, Generator, Any + +__all__ = ["_AstNode", "Block", "Value", "IndirectValue", "IntegerValue", "FloatValue", "StringValue", "RegisterValue", + "MemMappedValue", "Comment", "Label", "AssignmentStmt", "AugmentedAssignmentStmt", "ReturnStmt", + "InplaceIncrStmt", "InplaceDecrStmt", "IfCondition", "CallStmt", "InlineAsm", "BreakpointStmt"] + + +class _AstNode: + def __init__(self, sourceref: SourceRef) -> None: + self.sourceref = sourceref.copy() + + @property + def lineref(self) -> str: + return "src l. " + str(self.sourceref.line) + + +class Block(_AstNode): + _unnamed_block_labels = {} # type: Dict[Block, str] + + def __init__(self, name: str, sourceref: SourceRef, parent_scope: SymbolTable) -> None: + super().__init__(sourceref) + self.address = 0 + self.name = name + self.statements = [] # type: List[_AstNode] + self.symbols = SymbolTable(name, parent_scope, self) + + @property + def ignore(self) -> bool: + return not self.name and not self.address + + @property + def label_names(self) -> Set[str]: + return {symbol.name for symbol in self.symbols.iter_labels()} + + @property + def label(self) -> str: + if self.name: + return self.name + if self in self._unnamed_block_labels: + return self._unnamed_block_labels[self] + label = "il65_block_{:d}".format(len(self._unnamed_block_labels)) + self._unnamed_block_labels[self] = label + return label + + def lookup(self, dottedname: str) -> Tuple[Optional['Block'], Optional[Union[SymbolDefinition, SymbolTable]]]: + # Searches a name in the current block or globally, if the name is scoped (=contains a '.'). + # Does NOT utilize a symbol table from a preprocessing parse phase, only looks in the current. + try: + scope, result = self.symbols.lookup(dottedname) + return scope.owning_block, result + except (SymbolError, LookupError): + return None, None + + def all_statements(self) -> Generator[Tuple['Block', Optional[SubroutineDef], _AstNode], None, None]: + for stmt in self.statements: + yield self, None, stmt + for sub in self.symbols.iter_subroutines(True): + for stmt in sub.sub_block.statements: + yield sub.sub_block, sub, stmt + + +class Value(_AstNode): + def __init__(self, datatype: DataType, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(sourceref) + self.datatype = datatype + self.name = name + self.constant = constant + + def assignable_from(self, other: 'Value') -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + return False, "incompatible value for assignment" + + +class IndirectValue(Value): + # only constant integers, memmapped and register values are wrapped in this. + def __init__(self, value: Value, type_modifier: DataType, sourceref: SourceRef) -> None: + assert type_modifier + super().__init__(type_modifier, sourceref, value.name, False) + self.value = value + + def __str__(self): + return "".format(self.value, self.datatype, self.name) + + def __hash__(self): + return hash((self.datatype, self.name, self.value)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, IndirectValue): + return NotImplemented + elif self is other: + return True + else: + vvo = getattr(other.value, "value", getattr(other.value, "address", None)) + vvs = getattr(self.value, "value", getattr(self.value, "address", None)) + return (other.datatype, other.name, other.value.name, other.value.datatype, other.value.constant, vvo) == \ + (self.datatype, self.name, self.value.name, self.value.datatype, self.value.constant, vvs) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + if self.datatype == DataType.BYTE: + if other.datatype == DataType.BYTE: + return True, "" + if self.datatype == DataType.WORD: + if other.datatype in {DataType.BYTE, DataType.WORD} | STRING_DATATYPES: + return True, "" + if self.datatype == DataType.FLOAT: + if other.datatype in {DataType.BYTE, DataType.WORD, DataType.FLOAT}: + return True, "" + if isinstance(other, (IntegerValue, FloatValue, StringValue)): + rangefault = check_value_in_range(self.datatype, "", 1, other.value) + if rangefault: + return False, rangefault + return True, "" + return False, "incompatible value for indirect assignment (need byte, word, float or string)" + + +class IntegerValue(Value): + def __init__(self, value: Optional[int], sourceref: SourceRef, *, datatype: DataType = None, name: str = None) -> None: + if type(value) is int: + if datatype is None: + if 0 <= value < 0x100: + datatype = DataType.BYTE + elif value < 0x10000: + datatype = DataType.WORD + else: + raise OverflowError("value too big: ${:x}".format(value)) + else: + faultreason = check_value_in_range(datatype, "", 1, value) + if faultreason: + raise OverflowError(faultreason) + super().__init__(datatype, sourceref, name, True) + self.value = value + elif value is None: + if not name: + raise ValueError("when integer value is not given, the name symbol should be speicified") + super().__init__(datatype, sourceref, name, True) + self.value = None + else: + raise TypeError("invalid data type") + + def __hash__(self): + return hash((self.datatype, self.value, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, IntegerValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) + + def __str__(self): + return "".format(self.value, self.name) + + +class FloatValue(Value): + def __init__(self, value: float, sourceref: SourceRef, name: str = None) -> None: + if type(value) is float: + super().__init__(DataType.FLOAT, sourceref, name, True) + self.value = value + else: + raise TypeError("invalid data type") + + def __hash__(self): + return hash((self.datatype, self.value, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FloatValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) + + def __str__(self): + return "".format(self.value, self.name) + + +class StringValue(Value): + def __init__(self, value: str, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(DataType.STRING, sourceref, name, constant) + self.value = value + + def __hash__(self): + return hash((self.datatype, self.value, self.name, self.constant)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, StringValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name, other.constant) == (self.datatype, self.value, self.name, self.constant) + + def __str__(self): + return "".format(self.value, self.name, self.constant) + + +class RegisterValue(Value): + def __init__(self, register: str, datatype: DataType, sourceref: SourceRef, name: str = None) -> None: + assert datatype in (DataType.BYTE, DataType.WORD) + assert register in REGISTER_SYMBOLS + super().__init__(datatype, sourceref, name, False) + self.register = register + + def __hash__(self): + return hash((self.datatype, self.register, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RegisterValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.register, other.name) == (self.datatype, self.register, self.name) + + def __str__(self): + return "".format(self.register, self.datatype, self.name) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if isinstance(other, IndirectValue): + if self.datatype == DataType.BYTE: + if other.datatype == DataType.BYTE: + return True, "" + return False, "(unsigned) byte required" + if self.datatype == DataType.WORD: + if other.datatype in (DataType.BYTE, DataType.WORD): + return True, "" + return False, "(unsigned) byte required" + return False, "incompatible indirect value for register assignment" + if self.register in ("SC", "SI"): + if isinstance(other, IntegerValue) and other.value in (0, 1): + return True, "" + return False, "can only assign an integer constant value of 0 or 1 to SC and SI" + if self.constant: + return False, "cannot assign to a constant" + if isinstance(other, RegisterValue): + if other.register in {"SI", "SC", "SZ"}: + return False, "cannot explicitly assign from a status bit register alias" + if len(self.register) < len(other.register): + return False, "register size mismatch" + if isinstance(other, StringValue) and self.register in REGISTER_BYTES | REGISTER_SBITS: + return False, "string address requires 16 bits combined register" + if isinstance(other, IntegerValue): + if other.value is not None: + range_error = check_value_in_range(self.datatype, self.register, 1, other.value) + if range_error: + return False, range_error + return True, "" + if self.datatype == DataType.WORD: + return True, "" + return False, "cannot assign address to single register" + if isinstance(other, FloatValue): + range_error = check_value_in_range(self.datatype, self.register, 1, other.value) + if range_error: + return False, range_error + return True, "" + if self.datatype == DataType.BYTE: + if other.datatype != DataType.BYTE: + return False, "(unsigned) byte required" + return True, "" + if self.datatype == DataType.WORD: + if other.datatype in (DataType.BYTE, DataType.WORD) or other.datatype in STRING_DATATYPES: + return True, "" + return False, "(unsigned) byte, word or string required" + return False, "incompatible value for register assignment" + + +class MemMappedValue(Value): + def __init__(self, address: Optional[int], datatype: DataType, length: int, + sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(datatype, sourceref, name, constant) + self.address = address + self.length = length + assert address is None or type(address) is int + + def __hash__(self): + return hash((self.datatype, self.address, self.length, self.name, self.constant)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, MemMappedValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.address, other.length, other.name, other.constant) == \ + (self.datatype, self.address, self.length, self.name, self.constant) + + def __str__(self): + addr = "" if self.address is None else "${:04x}".format(self.address) + return "" \ + .format(addr, self.datatype, self.length, self.name, self.constant) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + if isinstance(other, IndirectValue): + return False, "can not yet assign memory mapped value from indirect value" # @todo indirect v assign + if self.datatype == DataType.BYTE: + if isinstance(other, (IntegerValue, RegisterValue, MemMappedValue)): + if other.datatype == DataType.BYTE: + return True, "" + return False, "(unsigned) byte required" + elif isinstance(other, FloatValue): + range_error = check_value_in_range(self.datatype, "", 1, other.value) + if range_error: + return False, range_error + return True, "" + else: + return False, "(unsigned) byte required" + elif self.datatype in (DataType.WORD, DataType.FLOAT): + if isinstance(other, (IntegerValue, FloatValue)): + range_error = check_value_in_range(self.datatype, "", 1, other.value) + if range_error: + return False, range_error + return True, "" + elif isinstance(other, (RegisterValue, MemMappedValue)): + if other.datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): + return True, "" + else: + return False, "byte or word or float required" + elif isinstance(other, StringValue): + if self.datatype == DataType.WORD: + return True, "" + return False, "string address requires 16 bits (a word)" + if self.datatype == DataType.BYTE: + return False, "(unsigned) byte required" + if self.datatype == DataType.WORD: + return False, "(unsigned) word required" + return False, "incompatible value for assignment" + + +class Comment(_AstNode): + def __init__(self, text: str, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.text = text + + +class Label(_AstNode): + def __init__(self, name: str, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.name = name + + +class AssignmentStmt(_AstNode): + def __init__(self, leftvalues: List[Value], right: Value, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.leftvalues = leftvalues + self.right = right + + def __str__(self): + return "".format(str(self.right), ",".join(str(lv) for lv in self.leftvalues)) + + _immediate_string_vars = {} # type: Dict[str, Tuple[str, str]] + + def desugar_immediate_string(self, containing_block: Block) -> None: + if self.right.name or not isinstance(self.right, StringValue): + return + if self.right.value in self._immediate_string_vars: + blockname, stringvar_name = self._immediate_string_vars[self.right.value] + if blockname: + self.right.name = blockname + '.' + stringvar_name + else: + self.right.name = stringvar_name + else: + stringvar_name = "il65_str_{:d}".format(id(self)) + value = self.right.value + containing_block.symbols.define_variable(stringvar_name, self.sourceref, DataType.STRING, value=value) + self.right.name = stringvar_name + self._immediate_string_vars[self.right.value] = (containing_block.name, stringvar_name) + + def remove_identity_lvalues(self) -> None: + for lv in self.leftvalues: + if lv == self.right: + print("{}: removed identity assignment".format(self.sourceref)) + remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right] + self.leftvalues = remaining_leftvalues + + def is_identity(self) -> bool: + return all(lv == self.right for lv in self.leftvalues) + + +class AugmentedAssignmentStmt(AssignmentStmt): + SUPPORTED_OPERATORS = {"+=", "-=", "&=", "|=", "^=", ">>=", "<<="} + + # full set: {"+=", "-=", "*=", "/=", "%=", "//=", "**=", "&=", "|=", "^=", ">>=", "<<="} + + def __init__(self, left: Value, operator: str, right: Value, sourceref: SourceRef) -> None: + assert operator in self.SUPPORTED_OPERATORS + super().__init__([left], right, sourceref) + self.operator = operator + + def __str__(self): + return "".format(str(self.leftvalues[0]), self.operator, str(self.right)) + + +class ReturnStmt(_AstNode): + def __init__(self, sourceref: SourceRef, a: Optional[Value] = None, + x: Optional[Value] = None, + y: Optional[Value] = None) -> None: + super().__init__(sourceref) + self.a = a + self.x = x + self.y = y + + +class InplaceIncrStmt(_AstNode): + def __init__(self, what: Value, howmuch: Union[int, float], sourceref: SourceRef) -> None: + super().__init__(sourceref) + assert howmuch > 0 + self.what = what + self.howmuch = howmuch + + +class InplaceDecrStmt(_AstNode): + def __init__(self, what: Value, howmuch: Union[int, float], sourceref: SourceRef) -> None: + super().__init__(sourceref) + assert howmuch > 0 + self.what = what + self.howmuch = howmuch + + +class IfCondition(_AstNode): + SWAPPED_OPERATOR = {"==": "==", + "!=": "!=", + "<=": ">=", + ">=": "<=", + "<": ">", + ">": "<"} + IF_STATUSES = {"cc", "cs", "vc", "vs", "eq", "ne", "true", "not", "zero", "pos", "neg", "lt", "gt", "le", "ge"} + + def __init__(self, ifstatus: str, leftvalue: Optional[Value], + operator: str, rightvalue: Optional[Value], sourceref: SourceRef) -> None: + assert ifstatus in self.IF_STATUSES + assert operator in (None, "") or operator in self.SWAPPED_OPERATOR + if operator: + assert ifstatus in ("true", "not", "zero") + super().__init__(sourceref) + self.ifstatus = ifstatus + self.lvalue = leftvalue + self.comparison_op = operator + self.rvalue = rightvalue + + def __str__(self): + return "".format(self.ifstatus, self.lvalue, self.comparison_op, self.rvalue) + + def make_if_true(self) -> bool: + # makes a condition of the form if_not a < b into: if a > b (gets rid of the not) + # returns whether the change was made or not + if self.ifstatus == "not" and self.comparison_op: + self.ifstatus = "true" + self.comparison_op = self.SWAPPED_OPERATOR[self.comparison_op] + return True + return False + + def swap(self) -> Tuple[Value, str, Value]: + self.lvalue, self.comparison_op, self.rvalue = self.rvalue, self.SWAPPED_OPERATOR[self.comparison_op], self.lvalue + return self.lvalue, self.comparison_op, self.rvalue + + +class CallStmt(_AstNode): + def __init__(self, sourceref: SourceRef, target: Optional[Value] = None, *, + address: Optional[int] = None, arguments: List[Tuple[str, Any]] = None, + outputs: List[Tuple[str, Value]] = None, is_goto: bool = False, + preserve_regs: bool = True, condition: IfCondition = None) -> None: + if not is_goto: + assert condition is None + super().__init__(sourceref) + self.target = target + self.address = address + self.arguments = arguments + self.outputvars = outputs + self.is_goto = is_goto + self.condition = condition + self.preserve_regs = preserve_regs + self.desugared_call_arguments = [] # type: List[AssignmentStmt] + self.desugared_output_assignments = [] # type: List[AssignmentStmt] + + +class InlineAsm(_AstNode): + def __init__(self, asmlines: List[str], sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.asmlines = asmlines + + +class BreakpointStmt(_AstNode): + def __init__(self, sourceref: SourceRef) -> None: + super().__init__(sourceref) diff --git a/il65/codegen.py b/il65/codegen.py index bee47b337..f74a8bf7c 100644 --- a/il65/codegen.py +++ b/il65/codegen.py @@ -15,6 +15,7 @@ import contextlib from functools import partial from typing import TextIO, Set, Union, List, Callable from .parse import ProgramFormat, ParseResult, Parser +from .astdefs import * from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \ STRING_DATATYPES, REGISTER_WORDS, REGISTER_BYTES, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE @@ -32,7 +33,7 @@ class CodeGenerator: self.generated_code = io.StringIO() self.p = partial(print, file=self.generated_code) self.previous_stmt_was_assignment = False - self.cur_block = None # type: ParseResult.Block + self.cur_block = None # type: Block def generate(self) -> None: print("\ngenerating assembly code") @@ -55,7 +56,7 @@ class CodeGenerator: if zpblock.label_names: raise CodeError("ZP block cannot contain labels") # can only contain code comments, or nothing at all - if not all(isinstance(s, ParseResult.Comment) for s in zpblock.statements): + if not all(isinstance(s, Comment) for s in zpblock.statements): raise CodeError("ZP block cannot contain code statements, only definitions and comments") def optimize(self) -> None: @@ -190,7 +191,7 @@ class CodeGenerator: for block in [b for b in self.parsed.blocks if b.name == "
"]: self.cur_block = block for s in block.statements: - if isinstance(s, ParseResult.Comment): + if isinstance(s, Comment): self.p(s.text) else: raise CodeError("header block cannot contain any other statements beside comments") @@ -200,7 +201,7 @@ class CodeGenerator: self.cur_block = zpblock self.p("\n; ---- zero page block: '{:s}' ----\t\t; src l. {:d}\n".format(zpblock.sourceref.file, zpblock.sourceref.line)) for s in zpblock.statements: - if isinstance(s, ParseResult.Comment): + if isinstance(s, Comment): self.p(s.text) else: raise CodeError("zp cannot contain any other statements beside comments") @@ -211,13 +212,13 @@ class CodeGenerator: block = self.parsed.find_block("main") statements = list(block.statements) for index, stmt in enumerate(statements): - if isinstance(stmt, ParseResult.Label) and stmt.name == "start": + if isinstance(stmt, Label) and stmt.name == "start": asmlines = [ "\t\tcld\t\t\t; clear decimal flag", "\t\tclc\t\t\t; clear carry flag", "\t\tclv\t\t\t; clear overflow flag", ] - statements.insert(index+1, ParseResult.InlineAsm(asmlines, stmt.sourceref)) + statements.insert(index+1, InlineAsm(asmlines, stmt.sourceref)) break block.statements = statements # generate @@ -260,7 +261,7 @@ class CodeGenerator: self.p("; end external subroutines") self.p("\t.pend\n") - def generate_block_vars(self, block: ParseResult.Block) -> None: + def generate_block_vars(self, block: Block) -> None: consts = [c for c in block.symbols.iter_constants()] if consts: self.p("; constants") @@ -351,49 +352,49 @@ class CodeGenerator: self.p("{:s}\n\t\t.ptext {:s}".format(vardef.name, self.output_string(str(vardef.value), True))) self.p(".enc 'none'") - def generate_statement(self, stmt: ParseResult._AstNode) -> None: - if isinstance(stmt, ParseResult.ReturnStmt): + def generate_statement(self, stmt: _AstNode) -> None: + if isinstance(stmt, ReturnStmt): if stmt.a: - if isinstance(stmt.a, ParseResult.IntegerValue): + if isinstance(stmt.a, IntegerValue): self.p("\t\tlda #{:d}".format(stmt.a.value)) else: raise CodeError("can only return immediate values for now") # XXX if stmt.x: - if isinstance(stmt.x, ParseResult.IntegerValue): + if isinstance(stmt.x, IntegerValue): self.p("\t\tldx #{:d}".format(stmt.x.value)) else: raise CodeError("can only return immediate values for now") # XXX if stmt.y: - if isinstance(stmt.y, ParseResult.IntegerValue): + if isinstance(stmt.y, IntegerValue): self.p("\t\tldy #{:d}".format(stmt.y.value)) else: raise CodeError("can only return immediate values for now") # XXX self.p("\t\trts") - elif isinstance(stmt, ParseResult.AugmentedAssignmentStmt): + elif isinstance(stmt, AugmentedAssignmentStmt): self.generate_augmented_assignment(stmt) - elif isinstance(stmt, ParseResult.AssignmentStmt): + elif isinstance(stmt, AssignmentStmt): self.generate_assignment(stmt) - elif isinstance(stmt, ParseResult.Label): + elif isinstance(stmt, Label): self.p("\n{:s}\t\t\t\t; {:s}".format(stmt.name, stmt.lineref)) - elif isinstance(stmt, (ParseResult.InplaceIncrStmt, ParseResult.InplaceDecrStmt)): + elif isinstance(stmt, (InplaceIncrStmt, InplaceDecrStmt)): self.generate_incr_or_decr(stmt) - elif isinstance(stmt, ParseResult.CallStmt): + elif isinstance(stmt, CallStmt): self.generate_call(stmt) - elif isinstance(stmt, ParseResult.InlineAsm): + elif isinstance(stmt, InlineAsm): self.p("\t\t; inline asm, " + stmt.lineref) for line in stmt.asmlines: self.p(line) self.p("\t\t; end inline asm, " + stmt.lineref) - elif isinstance(stmt, ParseResult.Comment): + elif isinstance(stmt, Comment): self.p(stmt.text) - elif isinstance(stmt, ParseResult.BreakpointStmt): + elif isinstance(stmt, BreakpointStmt): # put a marker in the source so that we can generate a list of breakpoints later self.p("\t\tnop\t; {:s} {:s}".format(self.BREAKPOINT_COMMENT_SIGNATURE, stmt.lineref)) else: raise CodeError("unknown statement " + repr(stmt)) - self.previous_stmt_was_assignment = isinstance(stmt, ParseResult.AssignmentStmt) + self.previous_stmt_was_assignment = isinstance(stmt, AssignmentStmt) - def generate_incr_or_decr(self, stmt: Union[ParseResult.InplaceIncrStmt, ParseResult.InplaceDecrStmt]) -> None: + def generate_incr_or_decr(self, stmt: Union[InplaceIncrStmt, InplaceDecrStmt]) -> None: if stmt.what.datatype == DataType.FLOAT: raise CodeError("incr/decr on float not yet supported") # @todo support incr/decr on float else: @@ -401,8 +402,8 @@ class CodeGenerator: assert stmt.howmuch > 0 if stmt.howmuch > 0xff: raise CodeError("only supports incr/decr by up to 255 for now") # XXX - is_incr = isinstance(stmt, ParseResult.InplaceIncrStmt) - if isinstance(stmt.what, ParseResult.RegisterValue): + is_incr = isinstance(stmt, InplaceIncrStmt) + if isinstance(stmt.what, RegisterValue): reg = stmt.what.register # note: these operations below are all checked to be ok if is_incr: @@ -505,10 +506,10 @@ class CodeGenerator: self.p("+\t\tpla") else: raise CodeError("invalid decr register: " + reg) - elif isinstance(stmt.what, (ParseResult.MemMappedValue, ParseResult.IndirectValue)): + elif isinstance(stmt.what, (MemMappedValue, IndirectValue)): what = stmt.what - if isinstance(what, ParseResult.IndirectValue): - if isinstance(what.value, ParseResult.IntegerValue): + if isinstance(what, IndirectValue): + if isinstance(what.value, IntegerValue): r_str = what.value.name or Parser.to_hex(what.value.value) else: raise CodeError("invalid incr indirect type", what.value) @@ -568,7 +569,7 @@ class CodeGenerator: else: raise CodeError("cannot in/decrement " + str(stmt.what)) - def generate_call(self, stmt: ParseResult.CallStmt) -> None: + def generate_call(self, stmt: CallStmt) -> None: self.p("\t\t\t\t\t; " + stmt.lineref) if stmt.condition: assert stmt.is_goto @@ -667,7 +668,7 @@ class CodeGenerator: raise CodeError("invalid if status " + ifs) self._generate_call_or_goto(stmt, branch_emitter) - def _generate_goto_conditional_truthvalue(self, stmt: ParseResult.CallStmt) -> None: + def _generate_goto_conditional_truthvalue(self, stmt: CallStmt) -> None: # the condition is just the 'truth value' of the single value, # this is translated into assembly by comparing the argument to zero. def branch_emitter_mmap(targetstr: str, is_goto: bool, target_indirect: bool) -> None: @@ -677,7 +678,7 @@ class CodeGenerator: assert stmt.condition.ifstatus in ("true", "not", "zero") branch, inverse_branch = ("bne", "beq") if stmt.condition.ifstatus == "true" else ("beq", "bne") cv = stmt.condition.lvalue - assert isinstance(cv, ParseResult.MemMappedValue) + assert isinstance(cv, MemMappedValue) cv_str = cv.name or Parser.to_hex(cv.address) if cv.datatype == DataType.BYTE: self.p("\t\tsta " + Parser.to_hex(Zeropage.SCRATCH_B1)) # need to save A, because the goto may not be taken @@ -709,7 +710,7 @@ class CodeGenerator: branch, inverse_branch = ("bne", "beq") if stmt.condition.ifstatus == "true" else ("beq", "bne") line_after_branch = "" cv = stmt.condition.lvalue - assert isinstance(cv, ParseResult.RegisterValue) + assert isinstance(cv, RegisterValue) if cv.register == 'A': self.p("\t\tcmp #0") elif cv.register == 'X': @@ -744,7 +745,7 @@ class CodeGenerator: assert stmt.condition.ifstatus in ("true", "not", "zero") assert not target_indirect cv = stmt.condition.lvalue.value # type: ignore - if isinstance(cv, ParseResult.RegisterValue): + if isinstance(cv, RegisterValue): branch = "bne" if stmt.condition.ifstatus == "true" else "beq" self.p("\t\tsta " + Parser.to_hex(Zeropage.SCRATCH_B1)) # need to save A, because the goto may not be taken if cv.register == 'Y': @@ -761,9 +762,9 @@ class CodeGenerator: self.p("+\t\tlda $ffff") self.p("\t\t{:s} {:s}".format(branch, targetstr)) self.p("\t\tlda " + Parser.to_hex(Zeropage.SCRATCH_B1)) # restore A - elif isinstance(cv, ParseResult.MemMappedValue): + elif isinstance(cv, MemMappedValue): raise CodeError("memmapped indirect should not occur, use the variable without indirection") - elif isinstance(cv, ParseResult.IntegerValue) and cv.constant: + elif isinstance(cv, IntegerValue) and cv.constant: branch, inverse_branch = ("bne", "beq") if stmt.condition.ifstatus == "true" else ("beq", "bne") cv_str = cv.name or Parser.to_hex(cv.value) if cv.datatype == DataType.BYTE: @@ -791,23 +792,23 @@ class CodeGenerator: raise CodeError("weird indirect type", str(cv)) cv = stmt.condition.lvalue - if isinstance(cv, ParseResult.RegisterValue): + if isinstance(cv, RegisterValue): self._generate_call_or_goto(stmt, branch_emitter_reg) - elif isinstance(cv, ParseResult.MemMappedValue): + elif isinstance(cv, MemMappedValue): self._generate_call_or_goto(stmt, branch_emitter_mmap) - elif isinstance(cv, ParseResult.IndirectValue): - if isinstance(cv.value, ParseResult.RegisterValue): + elif isinstance(cv, IndirectValue): + if isinstance(cv.value, RegisterValue): self._generate_call_or_goto(stmt, branch_emitter_indirect_cond) - elif isinstance(cv.value, ParseResult.MemMappedValue): + elif isinstance(cv.value, MemMappedValue): self._generate_call_or_goto(stmt, branch_emitter_indirect_cond) - elif isinstance(cv.value, ParseResult.IntegerValue) and cv.value.constant: + elif isinstance(cv.value, IntegerValue) and cv.value.constant: self._generate_call_or_goto(stmt, branch_emitter_indirect_cond) else: raise CodeError("weird indirect type", str(cv)) else: raise CodeError("need register, memmapped or indirect value", str(cv)) - def _generate_goto_conditional_comparison(self, stmt: ParseResult.CallStmt) -> None: + def _generate_goto_conditional_comparison(self, stmt: CallStmt) -> None: # the condition is lvalue operator rvalue raise NotImplementedError("no comparisons yet") # XXX comparisons assert stmt.condition.ifstatus in ("true", "not", "zero") @@ -816,13 +817,13 @@ class CodeGenerator: if lv.constant and not rv.constant: # if lv is a constant, swap the whole thing around so the constant is on the right lv, compare_operator, rv = stmt.condition.swap() - if isinstance(rv, ParseResult.RegisterValue): + if isinstance(rv, RegisterValue): # if rv is a register, make sure it comes first instead lv, compare_operator, rv = stmt.condition.swap() if lv.datatype != DataType.BYTE or rv.datatype != DataType.BYTE: raise CodeError("can only generate comparison code for byte values for now") # @todo compare non-bytes - if isinstance(lv, ParseResult.RegisterValue): - if isinstance(rv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): + if isinstance(rv, RegisterValue): self.p("\t\tst{:s} {:s}".format(rv.register.lower(), Parser.to_hex(Zeropage.SCRATCH_B1))) if lv.register == "A": self.p("\t\tcmp " + Parser.to_hex(Zeropage.SCRATCH_B1)) @@ -832,7 +833,7 @@ class CodeGenerator: self.p("\t\tcpy " + Parser.to_hex(Zeropage.SCRATCH_B1)) else: raise CodeError("wrong lvalue register") - elif isinstance(rv, ParseResult.IntegerValue): + elif isinstance(rv, IntegerValue): rvstr = rv.name or Parser.to_hex(rv.value) if lv.register == "A": self.p("\t\tcmp #" + rvstr) @@ -842,7 +843,7 @@ class CodeGenerator: self.p("\t\tcpy #" + rvstr) else: raise CodeError("wrong lvalue register") - elif isinstance(rv, ParseResult.MemMappedValue): + elif isinstance(rv, MemMappedValue): rvstr = rv.name or Parser.to_hex(rv.address) if lv.register == "A": self.p("\t\tcmp " + rvstr) @@ -854,14 +855,14 @@ class CodeGenerator: raise CodeError("wrong lvalue register") else: raise CodeError("invalid rvalue type in comparison", rv) - elif isinstance(lv, ParseResult.MemMappedValue): - assert not isinstance(rv, ParseResult.RegisterValue), "registers as rvalue should have been swapped with lvalue" - if isinstance(rv, ParseResult.IntegerValue): + elif isinstance(lv, MemMappedValue): + assert not isinstance(rv, RegisterValue), "registers as rvalue should have been swapped with lvalue" + if isinstance(rv, IntegerValue): self.p("\t\tsta " + Parser.to_hex(Zeropage.SCRATCH_B1)) # need to save A, because the goto may not be taken self.p("\t\tlda " + (lv.name or Parser.to_hex(lv.address))) self.p("\t\tcmp #" + (rv.name or Parser.to_hex(rv.value))) line_after_goto = "\t\tlda " + Parser.to_hex(Zeropage.SCRATCH_B1) # restore A - elif isinstance(rv, ParseResult.MemMappedValue): + elif isinstance(rv, MemMappedValue): rvstr = rv.name or Parser.to_hex(rv.address) self.p("\t\tsta " + Parser.to_hex(Zeropage.SCRATCH_B1)) # need to save A, because the goto may not be taken self.p("\t\tlda " + (lv.name or Parser.to_hex(lv.address))) @@ -872,7 +873,7 @@ class CodeGenerator: else: raise CodeError("invalid lvalue type in comparison", lv) - def _generate_call_or_goto(self, stmt: ParseResult.CallStmt, branch_emitter: Callable[[str, bool, bool], None]) -> None: + def _generate_call_or_goto(self, stmt: CallStmt, branch_emitter: Callable[[str, bool, bool], None]) -> None: def generate_param_assignments() -> None: for assign_stmt in stmt.desugared_call_arguments: self.generate_assignment(assign_stmt) @@ -884,15 +885,15 @@ class CodeGenerator: def params_load_a() -> bool: for assign_stmt in stmt.desugared_call_arguments: for lv in assign_stmt.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): if lv.register == 'A': return True return False - def unclobber_result_registers(registers: Set[str], output_assignments: List[ParseResult.AssignmentStmt]) -> None: + def unclobber_result_registers(registers: Set[str], output_assignments: List[AssignmentStmt]) -> None: for a in output_assignments: for lv in a.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): if len(lv.register) == 1: registers.discard(lv.register) else: @@ -905,7 +906,7 @@ class CodeGenerator: symblock = None targetdef = None if isinstance(targetdef, SubroutineDef): - if isinstance(stmt.target, ParseResult.MemMappedValue): + if isinstance(stmt.target, MemMappedValue): targetstr = stmt.target.name or Parser.to_hex(stmt.address) else: raise CodeError("call sub target should be mmapped") @@ -924,16 +925,16 @@ class CodeGenerator: branch_emitter(targetstr, False, False) generate_result_assignments() return - if isinstance(stmt.target, ParseResult.IndirectValue): + if isinstance(stmt.target, IndirectValue): if stmt.target.name: targetstr = stmt.target.name elif stmt.address is not None: targetstr = Parser.to_hex(stmt.address) elif stmt.target.value.name: targetstr = stmt.target.value.name - elif isinstance(stmt.target.value, ParseResult.RegisterValue): + elif isinstance(stmt.target.value, RegisterValue): targetstr = stmt.target.value.register - elif isinstance(stmt.target.value, ParseResult.IntegerValue): + elif isinstance(stmt.target.value, IntegerValue): targetstr = stmt.target.value.name or Parser.to_hex(stmt.target.value.value) else: raise CodeError("missing name", stmt.target.value) @@ -974,7 +975,7 @@ class CodeGenerator: targetstr = stmt.target.name elif stmt.address is not None: targetstr = Parser.to_hex(stmt.address) - elif isinstance(stmt.target, ParseResult.IntegerValue): + elif isinstance(stmt.target, IntegerValue): targetstr = stmt.target.name or Parser.to_hex(stmt.target.value) else: raise CodeError("missing name", stmt.target) @@ -991,24 +992,24 @@ class CodeGenerator: branch_emitter(targetstr, False, False) generate_result_assignments() - def generate_augmented_assignment(self, stmt: ParseResult.AugmentedAssignmentStmt) -> None: + def generate_augmented_assignment(self, stmt: AugmentedAssignmentStmt) -> None: # for instance: value += 3 lvalue = stmt.leftvalues[0] rvalue = stmt.right self.p("\t\t\t\t\t; " + stmt.lineref) - if isinstance(lvalue, ParseResult.RegisterValue): - if isinstance(rvalue, ParseResult.IntegerValue): + if isinstance(lvalue, RegisterValue): + if isinstance(rvalue, IntegerValue): self._generate_aug_reg_int(lvalue, stmt.operator, rvalue) - elif isinstance(rvalue, ParseResult.RegisterValue): + elif isinstance(rvalue, RegisterValue): self._generate_aug_reg_reg(lvalue, stmt.operator, rvalue) - elif isinstance(rvalue, ParseResult.MemMappedValue): + elif isinstance(rvalue, MemMappedValue): self._generate_aug_reg_mem(lvalue, stmt.operator, rvalue) else: raise CodeError("invalid rvalue for augmented assignment on register", str(rvalue)) else: raise CodeError("augmented assignment only implemented for registers for now") # XXX - def _generate_aug_reg_mem(self, lvalue: ParseResult.RegisterValue, operator: str, rvalue: ParseResult.MemMappedValue) -> None: + def _generate_aug_reg_mem(self, lvalue: RegisterValue, operator: str, rvalue: MemMappedValue) -> None: r_str = rvalue.name or Parser.to_hex(rvalue.address) if operator == "+=": if lvalue.register == "A": @@ -1106,7 +1107,7 @@ class CodeGenerator: elif operator == "<<=": raise CodeError("can not yet shift a variable amount") # XXX - def _generate_aug_reg_int(self, lvalue: ParseResult.RegisterValue, operator: str, rvalue: ParseResult.IntegerValue) -> None: + def _generate_aug_reg_int(self, lvalue: RegisterValue, operator: str, rvalue: IntegerValue) -> None: r_str = rvalue.name or Parser.to_hex(rvalue.value) if operator == "+=": if lvalue.register == "A": @@ -1238,7 +1239,7 @@ class CodeGenerator: else: raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo <<=.word - def _generate_aug_reg_reg(self, lvalue: ParseResult.RegisterValue, operator: str, rvalue: ParseResult.RegisterValue) -> None: + def _generate_aug_reg_reg(self, lvalue: RegisterValue, operator: str, rvalue: RegisterValue) -> None: if operator == "+=": if rvalue.register not in REGISTER_BYTES: raise CodeError("unsupported rvalue register for aug assign", str(rvalue)) # @todo +=.word @@ -1400,55 +1401,55 @@ class CodeGenerator: else: raise CodeError("unsupported lvalue register for aug assign", str(lvalue)) # @todo <<=.word - def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None: - def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue: - if isinstance(iv.value, ParseResult.MemMappedValue): + def generate_assignment(self, stmt: AssignmentStmt) -> None: + def unwrap_indirect(iv: IndirectValue) -> MemMappedValue: + if isinstance(iv.value, MemMappedValue): return iv.value - elif iv.value.constant and isinstance(iv.value, ParseResult.IntegerValue): - return ParseResult.MemMappedValue(iv.value.value, iv.datatype, 1, stmt.sourceref, iv.name) + elif iv.value.constant and isinstance(iv.value, IntegerValue): + return MemMappedValue(iv.value.value, iv.datatype, 1, stmt.sourceref, iv.name) else: raise CodeError("cannot yet generate code for assignment: non-constant and non-memmapped indirect") # XXX rvalue = stmt.right - if isinstance(rvalue, ParseResult.IndirectValue): + if isinstance(rvalue, IndirectValue): rvalue = unwrap_indirect(rvalue) self.p("\t\t\t\t\t; " + stmt.lineref) - if isinstance(rvalue, ParseResult.IntegerValue): + if isinstance(rvalue, IntegerValue): for lv in stmt.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): self.generate_assign_integer_to_reg(lv.register, rvalue) - elif isinstance(lv, ParseResult.MemMappedValue): + elif isinstance(lv, MemMappedValue): self.generate_assign_integer_to_mem(lv, rvalue) - elif isinstance(lv, ParseResult.IndirectValue): + elif isinstance(lv, IndirectValue): lv = unwrap_indirect(lv) self.generate_assign_integer_to_mem(lv, rvalue) else: raise CodeError("invalid assignment target (1)", str(stmt)) - elif isinstance(rvalue, ParseResult.RegisterValue): + elif isinstance(rvalue, RegisterValue): for lv in stmt.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): self.generate_assign_reg_to_reg(lv, rvalue.register) - elif isinstance(lv, ParseResult.MemMappedValue): + elif isinstance(lv, MemMappedValue): self.generate_assign_reg_to_memory(lv, rvalue.register) - elif isinstance(lv, ParseResult.IndirectValue): + elif isinstance(lv, IndirectValue): lv = unwrap_indirect(lv) self.generate_assign_reg_to_memory(lv, rvalue.register) else: raise CodeError("invalid assignment target (2)", str(stmt)) - elif isinstance(rvalue, ParseResult.StringValue): + elif isinstance(rvalue, StringValue): r_str = self.output_string(rvalue.value, True) for lv in stmt.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): if len(rvalue.value) == 1: self.generate_assign_char_to_reg(lv, r_str) else: self.generate_assign_string_to_reg(lv, rvalue) - elif isinstance(lv, ParseResult.MemMappedValue): + elif isinstance(lv, MemMappedValue): if len(rvalue.value) == 1: self.generate_assign_char_to_memory(lv, r_str) else: self.generate_assign_string_to_memory(lv, rvalue) - elif isinstance(lv, ParseResult.IndirectValue): + elif isinstance(lv, IndirectValue): lv = unwrap_indirect(lv) if len(rvalue.value) == 1: self.generate_assign_char_to_memory(lv, r_str) @@ -1456,22 +1457,22 @@ class CodeGenerator: self.generate_assign_string_to_memory(lv, rvalue) else: raise CodeError("invalid assignment target (2)", str(stmt)) - elif isinstance(rvalue, ParseResult.MemMappedValue): + elif isinstance(rvalue, MemMappedValue): for lv in stmt.leftvalues: - if isinstance(lv, ParseResult.RegisterValue): + if isinstance(lv, RegisterValue): self.generate_assign_mem_to_reg(lv.register, rvalue) - elif isinstance(lv, ParseResult.MemMappedValue): + elif isinstance(lv, MemMappedValue): self.generate_assign_mem_to_mem(lv, rvalue) - elif isinstance(lv, ParseResult.IndirectValue): + elif isinstance(lv, IndirectValue): lv = unwrap_indirect(lv) self.generate_assign_mem_to_mem(lv, rvalue) else: raise CodeError("invalid assignment target (4)", str(stmt)) - elif isinstance(rvalue, ParseResult.FloatValue): + elif isinstance(rvalue, FloatValue): for lv in stmt.leftvalues: - if isinstance(lv, ParseResult.MemMappedValue) and lv.datatype == DataType.FLOAT: + if isinstance(lv, MemMappedValue) and lv.datatype == DataType.FLOAT: self.generate_assign_float_to_mem(lv, rvalue) - elif isinstance(lv, ParseResult.IndirectValue): + elif isinstance(lv, IndirectValue): lv = unwrap_indirect(lv) assert lv.datatype == DataType.FLOAT self.generate_assign_float_to_mem(lv, rvalue) @@ -1480,8 +1481,8 @@ class CodeGenerator: else: raise CodeError("invalid assignment value type", str(stmt)) - def generate_assign_float_to_mem(self, mmv: ParseResult.MemMappedValue, - rvalue: Union[ParseResult.FloatValue, ParseResult.IntegerValue]) -> None: + def generate_assign_float_to_mem(self, mmv: MemMappedValue, + rvalue: Union[FloatValue, IntegerValue]) -> None: floatvalue = float(rvalue.value) mflpt = self.to_mflpt5(floatvalue) target = mmv.name or Parser.to_hex(mmv.address) @@ -1494,7 +1495,7 @@ class CodeGenerator: self.p("\t\tsta {:s}+{:d}".format(target, i)) self.p("\t\tpla") - def generate_assign_reg_to_memory(self, lv: ParseResult.MemMappedValue, r_register: str) -> None: + def generate_assign_reg_to_memory(self, lv: MemMappedValue, r_register: str) -> None: # Memory = Register lv_string = lv.name or Parser.to_hex(lv.address) if lv.datatype == DataType.BYTE: @@ -1554,7 +1555,7 @@ class CodeGenerator: else: raise CodeError("invalid lvalue type", lv.datatype) - def generate_assign_reg_to_reg(self, lv: ParseResult.RegisterValue, r_register: str) -> None: + def generate_assign_reg_to_reg(self, lv: RegisterValue, r_register: str) -> None: if lv.register != r_register: if lv.register == 'A': # x/y -> a self.p("\t\tt{:s}a".format(r_register.lower())) @@ -1669,7 +1670,7 @@ class CodeGenerator: else: yield - def generate_assign_integer_to_mem(self, lv: ParseResult.MemMappedValue, rvalue: ParseResult.IntegerValue) -> None: + def generate_assign_integer_to_mem(self, lv: MemMappedValue, rvalue: IntegerValue) -> None: if lv.name: symblock, sym = self.cur_block.lookup(lv.name) if not isinstance(sym, VariableDef): @@ -1701,7 +1702,7 @@ class CodeGenerator: else: raise CodeError("invalid lvalue type " + str(lvdatatype)) - def generate_assign_mem_to_reg(self, l_register: str, rvalue: ParseResult.MemMappedValue) -> None: + def generate_assign_mem_to_reg(self, l_register: str, rvalue: MemMappedValue) -> None: r_str = rvalue.name if rvalue.name else "${:x}".format(rvalue.address) if len(l_register) == 1: if rvalue.datatype != DataType.BYTE: @@ -1717,7 +1718,7 @@ class CodeGenerator: else: raise CodeError("can only assign a byte or word to a register pair") - def generate_assign_mem_to_mem(self, lv: ParseResult.MemMappedValue, rvalue: ParseResult.MemMappedValue) -> None: + def generate_assign_mem_to_mem(self, lv: MemMappedValue, rvalue: MemMappedValue) -> None: r_str = rvalue.name or Parser.to_hex(rvalue.address) l_str = lv.name or Parser.to_hex(lv.address) if lv.datatype == DataType.BYTE: @@ -1771,7 +1772,7 @@ class CodeGenerator: else: raise CodeError("invalid lvalue memmapped datatype", str(lv)) - def generate_assign_char_to_memory(self, lv: ParseResult.MemMappedValue, char_str: str) -> None: + def generate_assign_char_to_memory(self, lv: MemMappedValue, char_str: str) -> None: # Memory = Character with self.preserving_registers({'A'}, loads_a_within=True): self.p("\t\tlda #" + char_str) @@ -1795,7 +1796,7 @@ class CodeGenerator: else: raise CodeError("invalid lvalue type " + str(sym)) - def generate_assign_integer_to_reg(self, l_register: str, rvalue: ParseResult.IntegerValue) -> None: + def generate_assign_integer_to_reg(self, l_register: str, rvalue: IntegerValue) -> None: r_str = rvalue.name if rvalue.name else "${:x}".format(rvalue.value) if l_register in ('A', 'X', 'Y'): self.p("\t\tld{:s} #{:s}".format(l_register.lower(), r_str)) @@ -1817,13 +1818,13 @@ class CodeGenerator: else: raise CodeError("invalid register in immediate integer assignment", l_register, rvalue.value) - def generate_assign_char_to_reg(self, lv: ParseResult.RegisterValue, char_str: str) -> None: + def generate_assign_char_to_reg(self, lv: RegisterValue, char_str: str) -> None: # Register = Char (string of length 1) if lv.register not in ('A', 'X', 'Y'): raise CodeError("invalid register for char assignment", lv.register) self.p("\t\tld{:s} #{:s}".format(lv.register.lower(), char_str)) - def generate_assign_string_to_reg(self, lv: ParseResult.RegisterValue, rvalue: ParseResult.StringValue) -> None: + def generate_assign_string_to_reg(self, lv: RegisterValue, rvalue: StringValue) -> None: if lv.register not in ("AX", "AY", "XY"): raise CodeError("need register pair AX, AY or XY for string address assignment", lv.register) if rvalue.name: @@ -1832,7 +1833,7 @@ class CodeGenerator: else: raise CodeError("cannot assign immediate string, it should be a string variable") - def generate_assign_string_to_memory(self, lv: ParseResult.MemMappedValue, rvalue: ParseResult.StringValue) -> None: + def generate_assign_string_to_memory(self, lv: MemMappedValue, rvalue: StringValue) -> None: if lv.datatype != DataType.WORD: raise CodeError("need word memory type for string address assignment") if rvalue.name: diff --git a/il65/astparse.py b/il65/exprparse.py similarity index 100% rename from il65/astparse.py rename to il65/exprparse.py diff --git a/il65/parse.py b/il65/parse.py index fea99738c..94bbabfd9 100644 --- a/il65/parse.py +++ b/il65/parse.py @@ -13,11 +13,12 @@ import sys import shutil import enum from collections import defaultdict -from typing import Set, List, Tuple, Optional, Any, Dict, Union, Generator -from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\ +from typing import Set, List, Tuple, Optional, Dict, Union, Generator +from .exprparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\ parse_expr_as_string, parse_arguments, parse_expr_as_comparison +from .astdefs import * from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \ - Zeropage, check_value_in_range, char_to_bytevalue, \ + Zeropage, char_to_bytevalue, \ PrimitiveType, VariableDef, ConstantDef, SymbolError, STRING_DATATYPES, \ REGISTER_SYMBOLS, REGISTER_WORDS, REGISTER_BYTES, REGISTER_SBITS, RESERVED_NAMES @@ -35,17 +36,17 @@ class ParseResult: self.clobberzp = False self.restorezp = False self.start_address = 0 - self.blocks = [] # type: List['ParseResult.Block'] + self.blocks = [] # type: List[Block] self.subroutine_usage = defaultdict(set) # type: Dict[Tuple[str, str], Set[str]] self.zeropage = Zeropage() - def all_blocks(self) -> Generator['ParseResult.Block', None, None]: + def all_blocks(self) -> Generator[Block, None, None]: for block in self.blocks: yield block for sub in block.symbols.iter_subroutines(True): yield sub.sub_block - def add_block(self, block: 'ParseResult.Block', position: Optional[int]=None) -> None: + def add_block(self, block: Block, position: Optional[int]=None) -> None: if position is not None: self.blocks.insert(position, block) else: @@ -61,7 +62,7 @@ class ParseResult: if block.name != "
": self.blocks.append(block) - def find_block(self, name: str) -> 'Block': + def find_block(self, name: str) -> Block: for block in self.blocks: if block.name == name: return block @@ -70,507 +71,6 @@ class ParseResult: def sub_used_by(self, sub: SubroutineDef, sourceref: SourceRef) -> None: self.subroutine_usage[(sub.blockname, sub.name)].add(str(sourceref)) - class _AstNode: - def __init__(self, sourceref: SourceRef) -> None: - self.sourceref = sourceref.copy() - - @property - def lineref(self) -> str: - return "src l. " + str(self.sourceref.line) - - class Block(_AstNode): - _unnamed_block_labels = {} # type: Dict[ParseResult.Block, str] - - def __init__(self, name: str, sourceref: SourceRef, parent_scope: SymbolTable) -> None: - super().__init__(sourceref) - self.address = 0 - self.name = name - self.statements = [] # type: List[ParseResult._AstNode] - self.symbols = SymbolTable(name, parent_scope, self) - - @property - def ignore(self) -> bool: - return not self.name and not self.address - - @property - def label_names(self) -> Set[str]: - return {symbol.name for symbol in self.symbols.iter_labels()} - - @property - def label(self) -> str: - if self.name: - return self.name - if self in self._unnamed_block_labels: - return self._unnamed_block_labels[self] - label = "il65_block_{:d}".format(len(self._unnamed_block_labels)) - self._unnamed_block_labels[self] = label - return label - - def lookup(self, dottedname: str) -> Tuple[Optional['ParseResult.Block'], Optional[Union[SymbolDefinition, SymbolTable]]]: - # Searches a name in the current block or globally, if the name is scoped (=contains a '.'). - # Does NOT utilize a symbol table from a preprocessing parse phase, only looks in the current. - try: - scope, result = self.symbols.lookup(dottedname) - return scope.owning_block, result - except (SymbolError, LookupError): - return None, None - - def all_statements(self) -> Generator[Tuple['ParseResult.Block', Optional[SubroutineDef], 'ParseResult._AstNode'], None, None]: - for stmt in self.statements: - yield self, None, stmt - for sub in self.symbols.iter_subroutines(True): - for stmt in sub.sub_block.statements: - yield sub.sub_block, sub, stmt - - class Value(_AstNode): - def __init__(self, datatype: DataType, sourceref: SourceRef, name: str=None, constant: bool=False) -> None: - super().__init__(sourceref) - self.datatype = datatype - self.name = name - self.constant = constant - - def assignable_from(self, other: 'ParseResult.Value') -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - return False, "incompatible value for assignment" - - class IndirectValue(Value): - # only constant integers, memmapped and register values are wrapped in this. - def __init__(self, value: 'ParseResult.Value', type_modifier: DataType, sourceref: SourceRef) -> None: - assert type_modifier - super().__init__(type_modifier, sourceref, value.name, False) - self.value = value - - def __str__(self): - return "".format(self.value, self.datatype, self.name) - - def __hash__(self): - return hash((self.datatype, self.name, self.value)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.IndirectValue): - return NotImplemented - elif self is other: - return True - else: - vvo = getattr(other.value, "value", getattr(other.value, "address", None)) - vvs = getattr(self.value, "value", getattr(self.value, "address", None)) - return (other.datatype, other.name, other.value.name, other.value.datatype, other.value.constant, vvo) ==\ - (self.datatype, self.name, self.value.name, self.value.datatype, self.value.constant, vvs) - - def assignable_from(self, other: 'ParseResult.Value') -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - if self.datatype == DataType.BYTE: - if other.datatype == DataType.BYTE: - return True, "" - if self.datatype == DataType.WORD: - if other.datatype in {DataType.BYTE, DataType.WORD} | STRING_DATATYPES: - return True, "" - if self.datatype == DataType.FLOAT: - if other.datatype in {DataType.BYTE, DataType.WORD, DataType.FLOAT}: - return True, "" - if isinstance(other, (ParseResult.IntegerValue, ParseResult.FloatValue, ParseResult.StringValue)): - rangefault = check_value_in_range(self.datatype, "", 1, other.value) - if rangefault: - return False, rangefault - return True, "" - return False, "incompatible value for indirect assignment (need byte, word, float or string)" - - class IntegerValue(Value): - def __init__(self, value: Optional[int], sourceref: SourceRef, *, datatype: DataType=None, name: str=None) -> None: - if type(value) is int: - if datatype is None: - if 0 <= value < 0x100: - datatype = DataType.BYTE - elif value < 0x10000: - datatype = DataType.WORD - else: - raise OverflowError("value too big: ${:x}".format(value)) - else: - faultreason = check_value_in_range(datatype, "", 1, value) - if faultreason: - raise OverflowError(faultreason) - super().__init__(datatype, sourceref, name, True) - self.value = value - elif value is None: - if not name: - raise ValueError("when integer value is not given, the name symbol should be speicified") - super().__init__(datatype, sourceref, name, True) - self.value = None - else: - raise TypeError("invalid data type") - - def __hash__(self): - return hash((self.datatype, self.value, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.IntegerValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) - - def __str__(self): - return "".format(self.value, self.name) - - class FloatValue(Value): - def __init__(self, value: float, sourceref: SourceRef, name: str=None) -> None: - if type(value) is float: - super().__init__(DataType.FLOAT, sourceref, name, True) - self.value = value - else: - raise TypeError("invalid data type") - - def __hash__(self): - return hash((self.datatype, self.value, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.FloatValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) - - def __str__(self): - return "".format(self.value, self.name) - - class StringValue(Value): - def __init__(self, value: str, sourceref: SourceRef, name: str=None, constant: bool=False) -> None: - super().__init__(DataType.STRING, sourceref, name, constant) - self.value = value - - def __hash__(self): - return hash((self.datatype, self.value, self.name, self.constant)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.StringValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name, other.constant) == (self.datatype, self.value, self.name, self.constant) - - def __str__(self): - return "".format(self.value, self.name, self.constant) - - class RegisterValue(Value): - def __init__(self, register: str, datatype: DataType, sourceref: SourceRef, name: str=None) -> None: - assert datatype in (DataType.BYTE, DataType.WORD) - assert register in REGISTER_SYMBOLS - super().__init__(datatype, sourceref, name, False) - self.register = register - - def __hash__(self): - return hash((self.datatype, self.register, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.RegisterValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.register, other.name) == (self.datatype, self.register, self.name) - - def __str__(self): - return "".format(self.register, self.datatype, self.name) - - def assignable_from(self, other: 'ParseResult.Value') -> Tuple[bool, str]: - if isinstance(other, ParseResult.IndirectValue): - if self.datatype == DataType.BYTE: - if other.datatype == DataType.BYTE: - return True, "" - return False, "(unsigned) byte required" - if self.datatype == DataType.WORD: - if other.datatype in (DataType.BYTE, DataType.WORD): - return True, "" - return False, "(unsigned) byte required" - return False, "incompatible indirect value for register assignment" - if self.register in ("SC", "SI"): - if isinstance(other, ParseResult.IntegerValue) and other.value in (0, 1): - return True, "" - return False, "can only assign an integer constant value of 0 or 1 to SC and SI" - if self.constant: - return False, "cannot assign to a constant" - if isinstance(other, ParseResult.RegisterValue): - if other.register in {"SI", "SC", "SZ"}: - return False, "cannot explicitly assign from a status bit register alias" - if len(self.register) < len(other.register): - return False, "register size mismatch" - if isinstance(other, ParseResult.StringValue) and self.register in REGISTER_BYTES | REGISTER_SBITS: - return False, "string address requires 16 bits combined register" - if isinstance(other, ParseResult.IntegerValue): - if other.value is not None: - range_error = check_value_in_range(self.datatype, self.register, 1, other.value) - if range_error: - return False, range_error - return True, "" - if self.datatype == DataType.WORD: - return True, "" - return False, "cannot assign address to single register" - if isinstance(other, ParseResult.FloatValue): - range_error = check_value_in_range(self.datatype, self.register, 1, other.value) - if range_error: - return False, range_error - return True, "" - if self.datatype == DataType.BYTE: - if other.datatype != DataType.BYTE: - return False, "(unsigned) byte required" - return True, "" - if self.datatype == DataType.WORD: - if other.datatype in (DataType.BYTE, DataType.WORD) or other.datatype in STRING_DATATYPES: - return True, "" - return False, "(unsigned) byte, word or string required" - return False, "incompatible value for register assignment" - - class MemMappedValue(Value): - def __init__(self, address: Optional[int], datatype: DataType, length: int, - sourceref: SourceRef, name: str=None, constant: bool=False) -> None: - super().__init__(datatype, sourceref, name, constant) - self.address = address - self.length = length - assert address is None or type(address) is int - - def __hash__(self): - return hash((self.datatype, self.address, self.length, self.name, self.constant)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ParseResult.MemMappedValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.address, other.length, other.name, other.constant) ==\ - (self.datatype, self.address, self.length, self.name, self.constant) - - def __str__(self): - addr = "" if self.address is None else "${:04x}".format(self.address) - return ""\ - .format(addr, self.datatype, self.length, self.name, self.constant) - - def assignable_from(self, other: 'ParseResult.Value') -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - if isinstance(other, ParseResult.IndirectValue): - return False, "can not yet assign memory mapped value from indirect value" # @todo indirect v assign - if self.datatype == DataType.BYTE: - if isinstance(other, (ParseResult.IntegerValue, ParseResult.RegisterValue, ParseResult.MemMappedValue)): - if other.datatype == DataType.BYTE: - return True, "" - return False, "(unsigned) byte required" - elif isinstance(other, ParseResult.FloatValue): - range_error = check_value_in_range(self.datatype, "", 1, other.value) - if range_error: - return False, range_error - return True, "" - else: - return False, "(unsigned) byte required" - elif self.datatype in (DataType.WORD, DataType.FLOAT): - if isinstance(other, (ParseResult.IntegerValue, ParseResult.FloatValue)): - range_error = check_value_in_range(self.datatype, "", 1, other.value) - if range_error: - return False, range_error - return True, "" - elif isinstance(other, (ParseResult.RegisterValue, ParseResult.MemMappedValue)): - if other.datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): - return True, "" - else: - return False, "byte or word or float required" - elif isinstance(other, ParseResult.StringValue): - if self.datatype == DataType.WORD: - return True, "" - return False, "string address requires 16 bits (a word)" - if self.datatype == DataType.BYTE: - return False, "(unsigned) byte required" - if self.datatype == DataType.WORD: - return False, "(unsigned) word required" - return False, "incompatible value for assignment" - - class Comment(_AstNode): - def __init__(self, text: str, sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.text = text - - class Label(_AstNode): - def __init__(self, name: str, sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.name = name - - class AssignmentStmt(_AstNode): - def __init__(self, leftvalues: List['ParseResult.Value'], right: 'ParseResult.Value', sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.leftvalues = leftvalues - self.right = right - - def __str__(self): - return "".format(str(self.right), ",".join(str(lv) for lv in self.leftvalues)) - - _immediate_string_vars = {} # type: Dict[str, Tuple[str, str]] - - def desugar_immediate_string(self, parser: 'Parser') -> None: - if self.right.name or not isinstance(self.right, ParseResult.StringValue): - return - if self.right.value in self._immediate_string_vars: - blockname, stringvar_name = self._immediate_string_vars[self.right.value] - if blockname: - self.right.name = blockname + '.' + stringvar_name - else: - self.right.name = stringvar_name - else: - cur_block = parser.cur_block - stringvar_name = "il65_str_{:d}".format(id(self)) - value = self.right.value - cur_block.symbols.define_variable(stringvar_name, cur_block.sourceref, DataType.STRING, value=value) - self.right.name = stringvar_name - self._immediate_string_vars[self.right.value] = (cur_block.name, stringvar_name) - - def remove_identity_lvalues(self) -> None: - for lv in self.leftvalues: - if lv == self.right: - print("{}: removed identity assignment".format(self.sourceref)) - remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right] - self.leftvalues = remaining_leftvalues - - def is_identity(self) -> bool: - return all(lv == self.right for lv in self.leftvalues) - - class AugmentedAssignmentStmt(AssignmentStmt): - SUPPORTED_OPERATORS = {"+=", "-=", "&=", "|=", "^=", ">>=", "<<="} - # full set: {"+=", "-=", "*=", "/=", "%=", "//=", "**=", "&=", "|=", "^=", ">>=", "<<="} - - def __init__(self, left: 'ParseResult.Value', operator: str, right: 'ParseResult.Value', sourceref: SourceRef) -> None: - assert operator in self.SUPPORTED_OPERATORS - super().__init__([left], right, sourceref) - self.operator = operator - - def __str__(self): - return "".format(str(self.leftvalues[0]), self.operator, str(self.right)) - - class ReturnStmt(_AstNode): - def __init__(self, sourceref: SourceRef, a: Optional['ParseResult.Value']=None, - x: Optional['ParseResult.Value']=None, - y: Optional['ParseResult.Value']=None) -> None: - super().__init__(sourceref) - self.a = a - self.x = x - self.y = y - - class InplaceIncrStmt(_AstNode): - def __init__(self, what: 'ParseResult.Value', howmuch: Union[int, float], sourceref: SourceRef) -> None: - super().__init__(sourceref) - assert howmuch > 0 - self.what = what - self.howmuch = howmuch - - class InplaceDecrStmt(_AstNode): - def __init__(self, what: 'ParseResult.Value', howmuch: Union[int, float], sourceref: SourceRef) -> None: - super().__init__(sourceref) - assert howmuch > 0 - self.what = what - self.howmuch = howmuch - - class CallStmt(_AstNode): - def __init__(self, sourceref: SourceRef, target: Optional['ParseResult.Value']=None, *, - address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None, - outputs: List[Tuple[str, 'ParseResult.Value']]=None, is_goto: bool=False, - preserve_regs: bool=True, condition: 'ParseResult.IfCondition'=None) -> None: - if not is_goto: - assert condition is None - super().__init__(sourceref) - self.target = target - self.address = address - self.arguments = arguments - self.outputvars = outputs - self.is_goto = is_goto - self.condition = condition - self.preserve_regs = preserve_regs - self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt] - self.desugared_output_assignments = [] # type: List[ParseResult.AssignmentStmt] - - def desugar_call_arguments_and_outputs(self, parser: 'Parser') -> None: - self.desugared_call_arguments.clear() - self.desugared_output_assignments.clear() - for name, value in self.arguments or []: - assert name is not None, "all call arguments should have a name or be matched on a named parameter" - assignment = parser.parse_assignment(name, value) - if assignment.leftvalues[0].datatype != DataType.BYTE: - if isinstance(assignment.right, ParseResult.IntegerValue) and assignment.right.constant: - # a call that doesn't expect a BYTE argument but gets one, converted from a 1-byte string most likely - if value.startswith("'") and value.endswith("'"): - parser.print_warning("possible problematic string to byte conversion (use a .text var instead?)") - if not assignment.is_identity(): - assignment.sourceref = self.sourceref.copy() # @todo why set this? - self.desugared_call_arguments.append(assignment) - if all(not isinstance(v, ParseResult.RegisterValue) for r, v in self.outputvars or []): - # if none of the output variables are registers, we can simply generate the assignments without issues - for register, value in self.outputvars or []: - rvalue = parser.parse_expression(register) - assignment = ParseResult.AssignmentStmt([value], rvalue, self.sourceref) - self.desugared_output_assignments.append(assignment) - else: - result_reg_mapping = [(register, value.register, value) for register, value in self.outputvars or [] - if isinstance(value, ParseResult.RegisterValue)] - if any(r[0] != r[1] for r in result_reg_mapping): - # not all result parameter registers line up with the correct order of registers in the statement, - # reshuffling call results is not supported yet. - raise parser.PError("result registers and/or their ordering is not the same as in the " - "subroutine definition, this isn't supported yet") - else: - # no register alignment issues, just generate the assignments - # note: do not remove the identity assignment here or the output register handling generates buggy code - for register, value in self.outputvars or []: - rvalue = parser.parse_expression(register) - assignment = ParseResult.AssignmentStmt([value], rvalue, self.sourceref) - self.desugared_output_assignments.append(assignment) - - class InlineAsm(_AstNode): - def __init__(self, asmlines: List[str], sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.asmlines = asmlines - - class IfCondition(_AstNode): - SWAPPED_OPERATOR = {"==": "==", - "!=": "!=", - "<=": ">=", - ">=": "<=", - "<": ">", - ">": "<"} - IF_STATUSES = {"cc", "cs", "vc", "vs", "eq", "ne", "true", "not", "zero", "pos", "neg", "lt", "gt", "le", "ge"} - - def __init__(self, ifstatus: str, leftvalue: Optional['ParseResult.Value'], - operator: str, rightvalue: Optional['ParseResult.Value'], sourceref: SourceRef) -> None: - assert ifstatus in self.IF_STATUSES - assert operator in (None, "") or operator in self.SWAPPED_OPERATOR - if operator: - assert ifstatus in ("true", "not", "zero") - super().__init__(sourceref) - self.ifstatus = ifstatus - self.lvalue = leftvalue - self.comparison_op = operator - self.rvalue = rightvalue - - def __str__(self): - return "".format(self.ifstatus, self.lvalue, self.comparison_op, self.rvalue) - - def make_if_true(self) -> bool: - # makes a condition of the form if_not a < b into: if a > b (gets rid of the not) - # returns whether the change was made or not - if self.ifstatus == "not" and self.comparison_op: - self.ifstatus = "true" - self.comparison_op = self.SWAPPED_OPERATOR[self.comparison_op] - return True - return False - - def swap(self) -> Tuple['ParseResult.Value', str, 'ParseResult.Value']: - self.lvalue, self.comparison_op, self.rvalue = self.rvalue, self.SWAPPED_OPERATOR[self.comparison_op], self.lvalue - return self.lvalue, self.comparison_op, self.rvalue - - class BreakpointStmt(_AstNode): - def __init__(self, sourceref: SourceRef) -> None: - super().__init__(sourceref) - class Parser: def __init__(self, filename: str, outputdir: str, existing_imports: Set[str], parsing_import: bool = False, @@ -587,7 +87,7 @@ class Parser: self.outputdir = outputdir self.parsing_import = parsing_import # are we parsing a import file? self._cur_lineidx = -1 # used to efficiently go to next/previous line in source - self.cur_block = None # type: ParseResult.Block + self.cur_block = None # type: Block self.root_scope = SymbolTable("", None, None) self.root_scope.set_zeropage(self.result.zeropage) self.ppsymbols = ppsymbols # symboltable from preprocess phase @@ -663,13 +163,13 @@ class Parser: while True: line = self.next_line().lstrip() if line.startswith(';'): - self.cur_block.statements.append(ParseResult.Comment(line, self.sourceref)) + self.cur_block.statements.append(Comment(line, self.sourceref)) continue self.prev_line() break def _parse_1(self) -> None: - self.cur_block = ParseResult.Block("
", self.sourceref, self.root_scope) + self.cur_block = Block("
", self.sourceref, self.root_scope) self.result.add_block(self.cur_block) self.parse_header() if not self.parsing_import: @@ -704,14 +204,14 @@ class Parser: if not main_found: raise self.PError("a block 'main' should be defined and contain the program's entry point label 'start'") - def _check_return_statement(self, block: ParseResult.Block, message: str) -> None: + def _check_return_statement(self, block: Block, message: str) -> None: # find last statement that isn't a comment for stmt in reversed(block.statements): - if isinstance(stmt, ParseResult.Comment): + if isinstance(stmt, Comment): continue - if isinstance(stmt, ParseResult.ReturnStmt) or isinstance(stmt, ParseResult.CallStmt) and stmt.is_goto: + if isinstance(stmt, ReturnStmt) or isinstance(stmt, CallStmt) and stmt.is_goto: return - if isinstance(stmt, ParseResult.InlineAsm): + if isinstance(stmt, InlineAsm): # check that the last asm line is a jmp or a rts for asmline in reversed(stmt.asmlines): if asmline.lstrip().startswith(';'): @@ -731,27 +231,63 @@ class Parser: self.sourceref.line = -1 self.sourceref.column = 0 - def desugar_immediate_strings(stmt: ParseResult._AstNode) -> None: - if isinstance(stmt, ParseResult.CallStmt): + def desugar_immediate_strings(stmt: _AstNode, containing_block: Block) -> None: + if isinstance(stmt, CallStmt): for s in stmt.desugared_call_arguments: self.sourceref = s.sourceref.copy() - s.desugar_immediate_string(self) + s.desugar_immediate_string(containing_block) for s in stmt.desugared_output_assignments: self.sourceref = s.sourceref.copy() - s.desugar_immediate_string(self) - if isinstance(stmt, ParseResult.AssignmentStmt): + s.desugar_immediate_string(containing_block) + if isinstance(stmt, AssignmentStmt): self.sourceref = stmt.sourceref.copy() - stmt.desugar_immediate_string(self) + stmt.desugar_immediate_string(containing_block) for block in self.result.blocks: self.cur_block = block self.sourceref = block.sourceref.copy() self.sourceref.column = 0 - for block, sub, stmt in block.all_statements(): - if isinstance(stmt, ParseResult.CallStmt): + for _, sub, stmt in block.all_statements(): + if isinstance(stmt, CallStmt): self.sourceref = stmt.sourceref.copy() - stmt.desugar_call_arguments_and_outputs(self) - desugar_immediate_strings(stmt) + self.desugar_call_arguments_and_outputs(stmt) + desugar_immediate_strings(stmt, self.cur_block) + + def desugar_call_arguments_and_outputs(self, stmt: CallStmt) -> None: + stmt.desugared_call_arguments.clear() + stmt.desugared_output_assignments.clear() + for name, value in stmt.arguments or []: + assert name is not None, "all call arguments should have a name or be matched on a named parameter" + assignment = self.parse_assignment(name, value) + if assignment.leftvalues[0].datatype != DataType.BYTE: + if isinstance(assignment.right, IntegerValue) and assignment.right.constant: + # a call that doesn't expect a BYTE argument but gets one, converted from a 1-byte string most likely + if value.startswith("'") and value.endswith("'"): + self.print_warning("possible problematic string to byte conversion (use a .text var instead?)") + if not assignment.is_identity(): + assignment.sourceref = stmt.sourceref.copy() # @todo why set this? + stmt.desugared_call_arguments.append(assignment) + if all(not isinstance(v, RegisterValue) for r, v in stmt.outputvars or []): + # if none of the output variables are registers, we can simply generate the assignments without issues + for register, value in stmt.outputvars or []: + rvalue = self.parse_expression(register) + assignment = AssignmentStmt([value], rvalue, stmt.sourceref) + stmt.desugared_output_assignments.append(assignment) + else: + result_reg_mapping = [(register, value.register, value) for register, value in stmt.outputvars or [] + if isinstance(value, RegisterValue)] + if any(r[0] != r[1] for r in result_reg_mapping): + # not all result parameter registers line up with the correct order of registers in the statement, + # reshuffling call results is not supported yet. + raise self.PError("result registers and/or their ordering is not the same as in the " + "subroutine definition, this isn't supported yet") + else: + # no register alignment issues, just generate the assignments + # note: do not remove the identity assignment here or the output register handling generates buggy code + for register, value in stmt.outputvars or []: + rvalue = self.parse_expression(register) + assignment = AssignmentStmt([value], rvalue, stmt.sourceref) + stmt.desugared_output_assignments.append(assignment) def next_line(self) -> str: self._cur_lineidx += 1 @@ -929,7 +465,7 @@ class Parser: def create_import_parser(self, filename: str, outputdir: str) -> 'Parser': return Parser(filename, outputdir, self.existing_imports, True, ppsymbols=self.ppsymbols, sub_usage=self.result.subroutine_usage) - def parse_block(self) -> Optional[ParseResult.Block]: + def parse_block(self) -> Optional[Block]: # first line contains block header "~ [name] [addr]" followed by a '{' self._parse_comments() line = self.next_line() @@ -938,7 +474,7 @@ class Parser: raise self.PError("expected '~' (block)") block_args = line[1:].split() arg = "" - self.cur_block = ParseResult.Block("", self.sourceref.copy(), self.root_scope) + self.cur_block = Block("", self.sourceref.copy(), self.root_scope) is_zp_block = False while block_args: arg = block_args.pop(0) @@ -952,7 +488,7 @@ class Parser: raise self.PError("duplicate block name '{:s}', original definition at {}".format(arg, orig.sourceref)) self.cur_block = orig # zero page block occurrences are merged else: - self.cur_block = ParseResult.Block(arg, self.sourceref.copy(), self.root_scope) + self.cur_block = Block(arg, self.sourceref.copy(), self.root_scope) try: self.root_scope.define_scope(self.cur_block.symbols, self.cur_block.sourceref) except SymbolError as x: @@ -1033,7 +569,7 @@ class Parser: self.prev_line() self.cur_block.statements.append(self.parse_asm()) elif line == "breakpoint": - self.cur_block.statements.append(ParseResult.BreakpointStmt(self.sourceref)) + self.cur_block.statements.append(BreakpointStmt(self.sourceref)) self.print_warning("breakpoint defined") elif unstripped_line.startswith((" ", "\t")): if is_zp_block: @@ -1055,7 +591,7 @@ class Parser: if labelname in self.cur_block.symbols: raise self.PError("symbol already defined") self.cur_block.symbols.define_label(labelname, self.sourceref) - self.cur_block.statements.append(ParseResult.Label(labelname, self.sourceref)) + self.cur_block.statements.append(Label(labelname, self.sourceref)) if len(label_line) > 1: rest = label_line[1] self.cur_block.statements.append(self.parse_statement(rest)) @@ -1114,7 +650,7 @@ class Parser: if code_decl: address = None # parse the subroutine code lines (until the closing '}') - subroutine_block = ParseResult.Block(self.cur_block.name + "." + name, self.sourceref, self.cur_block.symbols) + subroutine_block = Block(self.cur_block.name + "." + name, self.sourceref, self.cur_block.symbols) current_block = self.cur_block self.cur_block = subroutine_block while True: @@ -1186,7 +722,7 @@ class Parser: datatype, length, matrix_dimensions = self.get_datatype(args[1]) return varname, datatype, length, matrix_dimensions, valuetext - def parse_statement(self, line: str) -> ParseResult._AstNode: + def parse_statement(self, line: str) -> _AstNode: match = re.fullmatch(r"(?Pif(_[a-z]+)?)\s+(?P.+)?goto\s+(?P[\S]+?)\s*(\((?P.*)\))?\s*", line) if match: # conditional goto @@ -1225,11 +761,11 @@ class Parser: elif line.endswith(("++", "--")): incr = line.endswith("++") what = self.parse_expression(line[:-2].rstrip()) - if isinstance(what, ParseResult.IntegerValue): + if isinstance(what, IntegerValue): raise self.PError("cannot in/decrement a constant value") if incr: - return ParseResult.InplaceIncrStmt(what, 1, self.sourceref) - return ParseResult.InplaceDecrStmt(what, 1, self.sourceref) + return InplaceIncrStmt(what, 1, self.sourceref) + return InplaceDecrStmt(what, 1, self.sourceref) else: # perhaps it is an augmented assignment statement match = re.fullmatch(r"(?P\S+)\s*(?P\+=|-=|\*=|/=|%=|//=|\*\*=|&=|\|=|\^=|>>=|<<=)\s*(?P\S.*)", line) @@ -1242,7 +778,7 @@ class Parser: raise self.PError("invalid statement") def parse_call_or_goto(self, targetstr: str, argumentstr: str, outputstr: str, - preserve_regs=True, is_goto=False, condition: ParseResult.IfCondition=None) -> ParseResult.CallStmt: + preserve_regs=True, is_goto=False, condition: IfCondition=None) -> CallStmt: if not is_goto: assert condition is None argumentstr = argumentstr.strip() if argumentstr else "" @@ -1251,7 +787,7 @@ class Parser: outputvars = None if argumentstr: arguments = parse_arguments(argumentstr, self.sourceref) - target = None # type: ParseResult.Value + target = None # type: Value if targetstr[0] == '[' and targetstr[-1] == ']': # indirect call to address in register pair or memory location targetstr, target = self.parse_indirect_value(targetstr, True) @@ -1259,12 +795,12 @@ class Parser: raise self.PError("invalid call target (should contain 16-bit)") else: target = self.parse_expression(targetstr) - if not isinstance(target, (ParseResult.IntegerValue, ParseResult.MemMappedValue, ParseResult.IndirectValue)): + if not isinstance(target, (IntegerValue, MemMappedValue, IndirectValue)): raise self.PError("cannot call that type of symbol") - if isinstance(target, ParseResult.IndirectValue) \ - and not isinstance(target.value, (ParseResult.IntegerValue, ParseResult.RegisterValue, ParseResult.MemMappedValue)): + if isinstance(target, IndirectValue) \ + and not isinstance(target.value, (IntegerValue, RegisterValue, MemMappedValue)): raise self.PError("cannot call that type of indirect symbol") - address = target.address if isinstance(target, ParseResult.MemMappedValue) else None + address = target.address if isinstance(target, MemMappedValue) else None try: _, symbol = self.lookup_with_ppsymbols(targetstr) except ParseError: @@ -1307,7 +843,7 @@ class Parser: # verify that all arguments have gotten a name if any(not a[0] for a in arguments or []): raise self.PError("all call arguments should have a name or be matched on a named parameter") - if isinstance(target, (type(None), ParseResult.Value)): + if isinstance(target, (type(None), Value)): # special case for the C-64 lib's print function, to be able to use it with a single character argument if target.name == "c64scr.print_string" and len(arguments) == 1 and isinstance(arguments[0], str): if arguments[0][1].startswith("'") and arguments[0][1].endswith("'"): @@ -1318,11 +854,11 @@ class Parser: assert len(newsymbol.parameters) == 1 arguments = [(newsymbol.parameters[0][1], arguments[0][1])] if is_goto: - return ParseResult.CallStmt(self.sourceref, target, address=address, - arguments=arguments, outputs=outputvars, is_goto=True, condition=condition) + return CallStmt(self.sourceref, target, address=address, arguments=arguments, + outputs=outputvars, is_goto=True, condition=condition) else: - return ParseResult.CallStmt(self.sourceref, target, address=address, - arguments=arguments, outputs=outputvars, preserve_regs=preserve_regs) + return CallStmt(self.sourceref, target, address=address, arguments=arguments, + outputs=outputvars, preserve_regs=preserve_regs) else: raise TypeError("target should be a Value", target) @@ -1334,7 +870,7 @@ class Parser: return int(text[1:], 2) return int(text) - def parse_assignment(self, *parts) -> ParseResult.AssignmentStmt: + def parse_assignment(self, *parts) -> AssignmentStmt: # parses the assignment of one rvalue to one or more lvalues l_values = [self.parse_expression(p) for p in parts[:-1]] r_value = self.parse_expression(parts[-1]) @@ -1346,16 +882,16 @@ class Parser: raise self.PError("cannot assign {0} to {1}; {2}".format(r_value, lv, reason)) if lv.datatype in (DataType.BYTE, DataType.WORD, DataType.MATRIX): # truncate the rvalue if needed - if isinstance(r_value, ParseResult.FloatValue): + if isinstance(r_value, FloatValue): truncated, value = self.coerce_value(self.sourceref, lv.datatype, r_value.value) if truncated: - r_value = ParseResult.IntegerValue(int(value), self.sourceref, datatype=lv.datatype, name=r_value.name) - return ParseResult.AssignmentStmt(l_values, r_value, self.sourceref) + r_value = IntegerValue(int(value), self.sourceref, datatype=lv.datatype, name=r_value.name) + return AssignmentStmt(l_values, r_value, self.sourceref) def parse_augmented_assignment(self, leftstr: str, operator: str, rightstr: str) \ - -> Union[ParseResult.AssignmentStmt, ParseResult.InplaceDecrStmt, ParseResult.InplaceIncrStmt]: + -> Union[AssignmentStmt, InplaceDecrStmt, InplaceIncrStmt]: # parses an augmented assignment (for instance: value += 3) - if operator not in ParseResult.AugmentedAssignmentStmt.SUPPORTED_OPERATORS: + if operator not in AugmentedAssignmentStmt.SUPPORTED_OPERATORS: raise self.PError("augmented assignment operator '{:s}' not supported".format(operator)) l_value = self.parse_expression(leftstr) r_value = self.parse_expression(rightstr) @@ -1363,28 +899,28 @@ class Parser: raise self.PError("can't have a constant as assignment target, perhaps you wanted indirection [...] instead?") if l_value.datatype in (DataType.BYTE, DataType.WORD, DataType.MATRIX): # truncate the rvalue if needed - if isinstance(r_value, ParseResult.FloatValue): + if isinstance(r_value, FloatValue): truncated, value = self.coerce_value(self.sourceref, l_value.datatype, r_value.value) if truncated: - r_value = ParseResult.IntegerValue(int(value), self.sourceref, datatype=l_value.datatype, name=r_value.name) + r_value = IntegerValue(int(value), self.sourceref, datatype=l_value.datatype, name=r_value.name) if r_value.constant and operator in ("+=", "-="): if operator == "+=": if r_value.value > 0: # type: ignore - return ParseResult.InplaceIncrStmt(l_value, r_value.value, self.sourceref) # type: ignore + return InplaceIncrStmt(l_value, r_value.value, self.sourceref) # type: ignore elif r_value.value < 0: # type: ignore - return ParseResult.InplaceDecrStmt(l_value, -r_value.value, self.sourceref) # type: ignore + return InplaceDecrStmt(l_value, -r_value.value, self.sourceref) # type: ignore else: self.print_warning("incr with zero, ignored") else: if r_value.value > 0: # type: ignore - return ParseResult.InplaceDecrStmt(l_value, r_value.value, self.sourceref) # type: ignore + return InplaceDecrStmt(l_value, r_value.value, self.sourceref) # type: ignore elif r_value.value < 0: # type: ignore - return ParseResult.InplaceIncrStmt(l_value, -r_value.value, self.sourceref) # type: ignore + return InplaceIncrStmt(l_value, -r_value.value, self.sourceref) # type: ignore else: self.print_warning("decr with zero, ignored") - return ParseResult.AugmentedAssignmentStmt(l_value, operator, r_value, self.sourceref) + return AugmentedAssignmentStmt(l_value, operator, r_value, self.sourceref) - def parse_return(self, line: str) -> ParseResult.ReturnStmt: + def parse_return(self, line: str) -> ReturnStmt: parts = line.split(maxsplit=1) if parts[0] != "return": raise self.PError("invalid statement, return expected") @@ -1393,7 +929,7 @@ class Parser: if len(parts) > 1: values = parts[1].split(",") if len(values) == 0: - return ParseResult.ReturnStmt(self.sourceref) + return ReturnStmt(self.sourceref) else: a = self.parse_expression(values[0]) if values[0] else None if len(values) > 1: @@ -1402,9 +938,9 @@ class Parser: y = self.parse_expression(values[2]) if values[2] else None if len(values) > 3: raise self.PError("too many returnvalues") - return ParseResult.ReturnStmt(self.sourceref, a, x, y) + return ReturnStmt(self.sourceref, a, x, y) - def parse_asm(self) -> ParseResult.InlineAsm: + def parse_asm(self) -> InlineAsm: line = self.next_line() aline = line.split() if not len(aline) == 2 or aline[0] != "asm" or aline[1] != "{": @@ -1413,7 +949,7 @@ class Parser: while True: line = self.next_line() if line.strip() == "}": - return ParseResult.InlineAsm(asmlines, self.sourceref) + return InlineAsm(asmlines, self.sourceref) # asm can refer to other symbols as well, track subroutine usage splits = line.split(maxsplit=1) if len(splits) == 2: @@ -1432,7 +968,7 @@ class Parser: self.result.sub_used_by(symbol, self.sourceref) asmlines.append(line) - def parse_asminclude(self, line: str) -> ParseResult.InlineAsm: + def parse_asminclude(self, line: str) -> InlineAsm: aline = line.split() if len(aline) < 2: raise self.PError("invalid asminclude or asmbinary statement") @@ -1454,7 +990,7 @@ class Parser: lines = ['{:s}\t.binclude "{:s}"'.format(scopename, filename)] else: raise self.PError("invalid asminclude statement") - return ParseResult.InlineAsm(lines, self.sourceref) + return InlineAsm(lines, self.sourceref) elif aline[0] == "asmbinary": if len(aline) == 4: offset = parse_expr_as_int(aline[2], None, None, self.sourceref) @@ -1467,11 +1003,11 @@ class Parser: lines = ['\t.binary "{:s}"'.format(filename)] else: raise self.PError("invalid asmbinary statement") - return ParseResult.InlineAsm(lines, self.sourceref) + return InlineAsm(lines, self.sourceref) else: raise self.PError("invalid statement") - def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value: + def parse_expression(self, text: str, is_indirect=False) -> Value: # parse an expression into whatever it is (primitive value, register, memory, register, etc) text = text.strip() if not text: @@ -1481,37 +1017,37 @@ class Parser: raise self.PError("using the address-of something in an indirect value makes no sense") # take the pointer (memory address) from the thing that follows this expression = self.parse_expression(text[1:]) - if isinstance(expression, ParseResult.StringValue): + if isinstance(expression, StringValue): return expression - elif isinstance(expression, ParseResult.MemMappedValue): - return ParseResult.IntegerValue(expression.address, self.sourceref, datatype=DataType.WORD, name=expression.name) + elif isinstance(expression, MemMappedValue): + return IntegerValue(expression.address, self.sourceref, datatype=DataType.WORD, name=expression.name) else: raise self.PError("cannot take the address of this type") elif text[0] in "-.0123456789$%~": number = parse_expr_as_number(text, self.cur_block.symbols, self.ppsymbols, self.sourceref) try: if type(number) is int: - return ParseResult.IntegerValue(int(number), self.sourceref) + return IntegerValue(int(number), self.sourceref) elif type(number) is float: - return ParseResult.FloatValue(number, self.sourceref) + return FloatValue(number, self.sourceref) else: raise TypeError("invalid number type") except (ValueError, OverflowError) as ex: raise self.PError(str(ex)) elif text in REGISTER_WORDS: - return ParseResult.RegisterValue(text, DataType.WORD, self.sourceref) + return RegisterValue(text, DataType.WORD, self.sourceref) elif text in REGISTER_BYTES | REGISTER_SBITS: - return ParseResult.RegisterValue(text, DataType.BYTE, self.sourceref) + return RegisterValue(text, DataType.BYTE, self.sourceref) elif (text.startswith("'") and text.endswith("'")) or (text.startswith('"') and text.endswith('"')): strvalue = parse_expr_as_string(text, self.cur_block.symbols, self.ppsymbols, self.sourceref) if len(strvalue) == 1: petscii_code = char_to_bytevalue(strvalue) - return ParseResult.IntegerValue(petscii_code, self.sourceref) - return ParseResult.StringValue(strvalue, self.sourceref) + return IntegerValue(petscii_code, self.sourceref) + return StringValue(strvalue, self.sourceref) elif text == "true": - return ParseResult.IntegerValue(1, self.sourceref) + return IntegerValue(1, self.sourceref) elif text == "false": - return ParseResult.IntegerValue(0, self.sourceref) + return IntegerValue(0, self.sourceref) elif self.is_identifier(text): symblock, sym = self.lookup_with_ppsymbols(text) if isinstance(sym, (VariableDef, ConstantDef)): @@ -1521,22 +1057,21 @@ class Parser: else: symbolname = "{:s}.{:s}".format(sym.blockname, sym.name) if isinstance(sym, VariableDef) and sym.register: - return ParseResult.RegisterValue(sym.register, sym.type, self.sourceref, name=symbolname) + return RegisterValue(sym.register, sym.type, self.sourceref, name=symbolname) elif sym.type in (DataType.BYTE, DataType.WORD, DataType.FLOAT): if isinstance(sym, ConstantDef): if sym.type == DataType.FLOAT: - return ParseResult.FloatValue(sym.value, self.sourceref, sym.name) # type: ignore + return FloatValue(sym.value, self.sourceref, sym.name) # type: ignore elif sym.type in (DataType.BYTE, DataType.WORD): - return ParseResult.IntegerValue(sym.value, self.sourceref, datatype=sym.type, name=sym.name) # type: ignore + return IntegerValue(sym.value, self.sourceref, datatype=sym.type, name=sym.name) # type: ignore elif sym.type in STRING_DATATYPES: - return ParseResult.StringValue(sym.value, self.sourceref, sym.name, True) # type: ignore + return StringValue(sym.value, self.sourceref, sym.name, True) # type: ignore else: raise TypeError("invalid const type", sym.type) else: - return ParseResult.MemMappedValue(sym.address, sym.type, sym.length, - self.sourceref, name=symbolname, constant=constant) + return MemMappedValue(sym.address, sym.type, sym.length, self.sourceref, name=symbolname, constant=constant) elif sym.type in STRING_DATATYPES: - return ParseResult.StringValue(sym.value, self.sourceref, name=symbolname, constant=constant) # type: ignore + return StringValue(sym.value, self.sourceref, name=symbolname, constant=constant) # type: ignore elif sym.type == DataType.MATRIX: raise self.PError("cannot manipulate matrix directly, use one of the matrix procedures") elif sym.type == DataType.BYTEARRAY or sym.type == DataType.WORDARRAY: @@ -1545,11 +1080,11 @@ class Parser: raise self.PError("invalid symbol type") elif isinstance(sym, LabelDef): name = sym.name if symblock is self.cur_block else sym.blockname + '.' + sym.name - return ParseResult.MemMappedValue(None, DataType.WORD, 1, self.sourceref, name, True) + return MemMappedValue(None, DataType.WORD, 1, self.sourceref, name, True) elif isinstance(sym, SubroutineDef): self.result.sub_used_by(sym, self.sourceref) name = sym.name if symblock is self.cur_block else sym.blockname + '.' + sym.name - return ParseResult.MemMappedValue(sym.address, DataType.WORD, 1, self.sourceref, name, True) + return MemMappedValue(sym.address, DataType.WORD, 1, self.sourceref, name, True) else: raise self.PError("invalid symbol type") elif text.startswith('[') and text.endswith(']'): @@ -1557,7 +1092,7 @@ class Parser: else: raise self.PError("invalid single value '" + text + "'") # @todo understand complex expressions - def parse_indirect_value(self, text: str, allow_mmapped_for_call: bool=False) -> Tuple[str, ParseResult.IndirectValue]: + def parse_indirect_value(self, text: str, allow_mmapped_for_call: bool=False) -> Tuple[str, IndirectValue]: indirect = text[1:-1].strip() indirect2, sep, typestr = indirect.rpartition('.') type_modifier = None @@ -1566,23 +1101,23 @@ class Parser: type_modifier, type_len, _ = self.get_datatype(sep + typestr) indirect = indirect2 expr = self.parse_expression(indirect, True) - if not isinstance(expr, (ParseResult.IntegerValue, ParseResult.MemMappedValue, ParseResult.RegisterValue)): + if not isinstance(expr, (IntegerValue, MemMappedValue, RegisterValue)): raise self.PError("only integers, memmapped vars, and registers can be used in an indirect value") if type_modifier is None: - if isinstance(expr, (ParseResult.RegisterValue, ParseResult.MemMappedValue)): + if isinstance(expr, (RegisterValue, MemMappedValue)): type_modifier = expr.datatype else: type_modifier = DataType.BYTE - if isinstance(expr, ParseResult.IntegerValue): + if isinstance(expr, IntegerValue): if type_modifier not in (DataType.BYTE, DataType.WORD, DataType.FLOAT): raise self.PError("invalid type modifier for the value's datatype") - elif isinstance(expr, ParseResult.MemMappedValue): + elif isinstance(expr, MemMappedValue): if allow_mmapped_for_call: if type_modifier and expr.datatype != type_modifier: raise self.PError("invalid type modifier for the value's datatype, must be " + expr.datatype.name) else: raise self.PError("use variable directly instead of using indirect addressing") - return indirect, ParseResult.IndirectValue(expr, type_modifier, self.sourceref) + return indirect, IndirectValue(expr, type_modifier, self.sourceref) def is_identifier(self, name: str) -> bool: if name.isidentifier(): @@ -1592,7 +1127,7 @@ class Parser: return blockname.isidentifier() and name.isidentifier() return False - def lookup_with_ppsymbols(self, dottedname: str) -> Tuple[ParseResult.Block, Union[SymbolDefinition, SymbolTable]]: + def lookup_with_ppsymbols(self, dottedname: str) -> Tuple[Block, Union[SymbolDefinition, SymbolTable]]: # Tries to find a symbol, if it cannot be located, the symbol table from the preprocess parse phase is consulted as well symblock, sym = self.cur_block.lookup(dottedname) if sym is None and self.ppsymbols: @@ -1665,21 +1200,21 @@ class Parser: result = [sentence[i:j].strip(separators) for i, j in zip(indices, indices[1:])] return list(filter(None, result)) # remove empty strings - def parse_if_condition(self, ifpart: str, conditionpart: str) -> ParseResult.IfCondition: + def parse_if_condition(self, ifpart: str, conditionpart: str) -> IfCondition: if ifpart == "if": ifstatus = "true" else: ifstatus = ifpart[3:] - if ifstatus not in ParseResult.IfCondition.IF_STATUSES: + if ifstatus not in IfCondition.IF_STATUSES: raise self.PError("invalid if form") if conditionpart: if ifstatus not in ("true", "not", "zero"): raise self.PError("can only have if[_true], if_not or if_zero when using a comparison expression") left, operator, right = parse_expr_as_comparison(conditionpart, self.sourceref) leftv = self.parse_expression(left) - if not operator and isinstance(leftv, (ParseResult.IntegerValue, ParseResult.FloatValue, ParseResult.StringValue)): + if not operator and isinstance(leftv, (IntegerValue, FloatValue, StringValue)): raise self.PError("condition is a constant value") - if isinstance(leftv, ParseResult.RegisterValue): + if isinstance(leftv, RegisterValue): if leftv.register in {"SC", "SZ", "SI"}: raise self.PError("cannot use a status bit register explicitly in a condition") if operator: @@ -1688,9 +1223,9 @@ class Parser: rightv = None if leftv == rightv: raise self.PError("left and right values in comparison are identical") - result = ParseResult.IfCondition(ifstatus, leftv, operator, rightv, self.sourceref) + result = IfCondition(ifstatus, leftv, operator, rightv, self.sourceref) else: - result = ParseResult.IfCondition(ifstatus, None, "", None, self.sourceref) + result = IfCondition(ifstatus, None, "", None, self.sourceref) if result.make_if_true(): self.print_warning("if_not condition inverted to if") return result @@ -1718,13 +1253,13 @@ class Optimizer: self.optimize_compare_with_zero(block) return self.parsed - def optimize_compare_with_zero(self, block: ParseResult.Block) -> None: + def optimize_compare_with_zero(self, block: Block) -> None: # a conditional goto that compares a value to zero will be simplified # the comparison operator and rvalue (0) will be removed and the if-status changed accordingly for stmt in block.statements: - if isinstance(stmt, ParseResult.CallStmt): + if isinstance(stmt, CallStmt): cond = stmt.condition - if cond and isinstance(cond.rvalue, (ParseResult.IntegerValue, ParseResult.FloatValue)) and cond.rvalue.value == 0: + if cond and isinstance(cond.rvalue, (IntegerValue, FloatValue)) and cond.rvalue.value == 0: simplified = False if cond.ifstatus in ("true", "ne"): if cond.comparison_op == "==": @@ -1749,12 +1284,12 @@ class Optimizer: if simplified: print("{}: simplified comparison with zero".format(stmt.sourceref)) - def combine_assignments_into_multi(self, block: ParseResult.Block) -> None: + def combine_assignments_into_multi(self, block: Block) -> None: # fold multiple consecutive assignments with the same rvalue into one multi-assignment - statements = [] # type: List[ParseResult._AstNode] + statements = [] # type: List[_AstNode] multi_assign_statement = None for stmt in block.statements: - if isinstance(stmt, ParseResult.AssignmentStmt) and not isinstance(stmt, ParseResult.AugmentedAssignmentStmt): + if isinstance(stmt, AssignmentStmt) and not isinstance(stmt, AugmentedAssignmentStmt): if multi_assign_statement and multi_assign_statement.right == stmt.right: multi_assign_statement.leftvalues.extend(stmt.leftvalues) print("{}: joined with previous line into multi-assign statement".format(stmt.sourceref)) @@ -1771,10 +1306,10 @@ class Optimizer: statements.append(multi_assign_statement) block.statements = statements - def optimize_multiassigns(self, block: ParseResult.Block) -> None: + def optimize_multiassigns(self, block: Block) -> None: # optimize multi-assign statements. for stmt in block.statements: - if isinstance(stmt, ParseResult.AssignmentStmt) and len(stmt.leftvalues) > 1: + if isinstance(stmt, AssignmentStmt) and len(stmt.leftvalues) > 1: # remove duplicates lvalues = list(set(stmt.leftvalues)) if len(lvalues) != len(stmt.leftvalues): @@ -1782,10 +1317,10 @@ class Optimizer: # change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) stmt.leftvalues = list(sorted(lvalues, key=_value_sortkey)) - def remove_identity_assigns(self, block: ParseResult.Block) -> None: + def remove_identity_assigns(self, block: Block) -> None: have_removed_stmts = False for index, stmt in enumerate(list(block.statements)): - if isinstance(stmt, ParseResult.AssignmentStmt): + if isinstance(stmt, AssignmentStmt): stmt.remove_identity_lvalues() if not stmt.leftvalues: print("{}: removed identity assignment statement".format(stmt.sourceref)) @@ -1795,7 +1330,7 @@ class Optimizer: # remove the Nones block.statements = [s for s in block.statements if s is not None] - def remove_unused_subroutines(self, block: ParseResult.Block) -> None: + def remove_unused_subroutines(self, block: Block) -> None: # some symbols are used by the emitted assembly code from the code generator, # and should never be removed or the assembler will fail never_remove = {"c64.GIVUAYF", "c64.FREADUY", "c64.FTOMEMXY"} @@ -1809,14 +1344,14 @@ class Optimizer: print("{}: discarded {:d} unused subroutines from block '{:s}'".format(block.sourceref, len(discarded), block.name)) -def _value_sortkey(value: ParseResult.Value) -> int: - if isinstance(value, ParseResult.RegisterValue): +def _value_sortkey(value: Value) -> int: + if isinstance(value, RegisterValue): num = 0 for char in value.register: num *= 100 num += ord(char) return num - elif isinstance(value, ParseResult.MemMappedValue): + elif isinstance(value, MemMappedValue): if value.address is None: return 99999999 if value.address < 0x100: diff --git a/il65/preprocess.py b/il65/preprocess.py index b9d75c9cd..b64f32ab3 100644 --- a/il65/preprocess.py +++ b/il65/preprocess.py @@ -9,6 +9,7 @@ License: GNU GPL 3.0, see LICENSE from typing import List, Tuple, Set from .parse import Parser, ParseResult, SymbolTable, SymbolDefinition from .symbols import SourceRef +from .astdefs import _AstNode, InlineAsm class PreprocessingParser(Parser): @@ -41,10 +42,10 @@ class PreprocessingParser(Parser): self._parse_1() return self.result - def parse_asminclude(self, line: str) -> ParseResult.InlineAsm: - return ParseResult.InlineAsm([], self.sourceref) + def parse_asminclude(self, line: str) -> InlineAsm: + return InlineAsm([], self.sourceref) - def parse_statement(self, line: str) -> ParseResult._AstNode: + def parse_statement(self, line: str) -> _AstNode: return None # type: ignore def parse_var_def(self, line: str) -> None: @@ -62,7 +63,7 @@ class PreprocessingParser(Parser): def parse_subroutine_def(self, line: str) -> None: super().parse_subroutine_def(line) - def create_import_parser(self, filename: str, outputdir: str) -> 'Parser': + def create_import_parser(self, filename: str, outputdir: str) -> Parser: return PreprocessingParser(filename, parsing_import=True, existing_imports=self.existing_imports) def print_import_progress(self, message: str, *args: str) -> None: