From 6511283bb87f71e4468a17c9babcb2db2dc18ff1 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Mon, 1 Jan 2018 23:46:33 +0100 Subject: [PATCH] refactor --- il65/astdefs.py | 511 -------------------------------------- il65/codegen.py | 9 +- il65/main.py | 3 +- il65/optimize.py | 155 ++++++++++++ il65/parse.py | 182 ++------------ il65/preprocess.py | 5 +- il65/symbols.py | 493 +++++++++++++++++++++++++++++++++++- lib/mathlib.ill | 72 ++++++ reference.md | 4 +- testsource/numbergame.ill | 18 +- todo.ill | 44 +++- 11 files changed, 800 insertions(+), 696 deletions(-) delete mode 100644 il65/astdefs.py create mode 100644 il65/optimize.py diff --git a/il65/astdefs.py b/il65/astdefs.py deleted file mode 100644 index 9fd5b1a5d..000000000 --- a/il65/astdefs.py +++ /dev/null @@ -1,511 +0,0 @@ -""" -Programming Language for 6502/6510 microprocessors -These are the Abstract Syntax Tree node classes that form the Parse Tree. - -Written by Irmen de Jong (irmen@razorvine.net) -License: GNU GPL 3.0, see LICENSE -""" - -from .symbols import SourceRef, SymbolTable, SubroutineDef, SymbolDefinition, SymbolError, DataType, \ - STRING_DATATYPES, REGISTER_SYMBOLS, REGISTER_BYTES, REGISTER_SBITS, check_value_in_range -from typing import Dict, Set, List, Tuple, Optional, Union, Generator, Any - -__all__ = ["_AstNode", "Block", "Value", "IndirectValue", "IntegerValue", "FloatValue", "StringValue", "RegisterValue", - "MemMappedValue", "Comment", "Label", "AssignmentStmt", "AugmentedAssignmentStmt", "ReturnStmt", - "InplaceIncrStmt", "InplaceDecrStmt", "IfCondition", "CallStmt", "InlineAsm", "BreakpointStmt"] - - -class _AstNode: - def __init__(self, sourceref: SourceRef) -> None: - self.sourceref = sourceref.copy() - - @property - def lineref(self) -> str: - return "src l. " + str(self.sourceref.line) - - -class Block(_AstNode): - _unnamed_block_labels = {} # type: Dict[Block, str] - - def __init__(self, name: str, sourceref: SourceRef, parent_scope: SymbolTable, preserve_registers: bool=False) -> None: - super().__init__(sourceref) - self.address = 0 - self.name = name - self.statements = [] # type: List[_AstNode] - self.symbols = SymbolTable(name, parent_scope, self) - self.preserve_registers = preserve_registers - - @property - def ignore(self) -> bool: - return not self.name and not self.address - - @property - def label_names(self) -> Set[str]: - return {symbol.name for symbol in self.symbols.iter_labels()} - - @property - def label(self) -> str: - if self.name: - return self.name - if self in self._unnamed_block_labels: - return self._unnamed_block_labels[self] - label = "il65_block_{:d}".format(len(self._unnamed_block_labels)) - self._unnamed_block_labels[self] = label - return label - - def lookup(self, dottedname: str) -> Tuple[Optional['Block'], Optional[Union[SymbolDefinition, SymbolTable]]]: - # Searches a name in the current block or globally, if the name is scoped (=contains a '.'). - # Does NOT utilize a symbol table from a preprocessing parse phase, only looks in the current. - try: - scope, result = self.symbols.lookup(dottedname) - return scope.owning_block, result - except (SymbolError, LookupError): - return None, None - - def all_statements(self) -> Generator[Tuple['Block', Optional[SubroutineDef], _AstNode], None, None]: - for stmt in self.statements: - yield self, None, stmt - for sub in self.symbols.iter_subroutines(True): - for stmt in sub.sub_block.statements: - yield sub.sub_block, sub, stmt - - -class Value(_AstNode): - def __init__(self, datatype: DataType, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: - super().__init__(sourceref) - self.datatype = datatype - self.name = name - self.constant = constant - - def assignable_from(self, other: 'Value') -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - return False, "incompatible value for assignment" - - -class IndirectValue(Value): - # only constant integers, memmapped and register values are wrapped in this. - def __init__(self, value: Value, type_modifier: DataType, sourceref: SourceRef) -> None: - assert type_modifier - super().__init__(type_modifier, sourceref, value.name, False) - self.value = value - - def __str__(self): - return "".format(self.value, self.datatype, self.name) - - def __hash__(self): - return hash((self.datatype, self.name, self.value)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, IndirectValue): - return NotImplemented - elif self is other: - return True - else: - vvo = getattr(other.value, "value", getattr(other.value, "address", None)) - vvs = getattr(self.value, "value", getattr(self.value, "address", None)) - return (other.datatype, other.name, other.value.name, other.value.datatype, other.value.constant, vvo) == \ - (self.datatype, self.name, self.value.name, self.value.datatype, self.value.constant, vvs) - - def assignable_from(self, other: Value) -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - if self.datatype == DataType.BYTE: - if other.datatype == DataType.BYTE: - return True, "" - if self.datatype == DataType.WORD: - if other.datatype in {DataType.BYTE, DataType.WORD} | STRING_DATATYPES: - return True, "" - if self.datatype == DataType.FLOAT: - if other.datatype in {DataType.BYTE, DataType.WORD, DataType.FLOAT}: - return True, "" - if isinstance(other, (IntegerValue, FloatValue, StringValue)): - rangefault = check_value_in_range(self.datatype, "", 1, other.value) - if rangefault: - return False, rangefault - return True, "" - return False, "incompatible value for indirect assignment (need byte, word, float or string)" - - -class IntegerValue(Value): - def __init__(self, value: Optional[int], sourceref: SourceRef, *, datatype: DataType = None, name: str = None) -> None: - if type(value) is int: - if datatype is None: - if 0 <= value < 0x100: - datatype = DataType.BYTE - elif value < 0x10000: - datatype = DataType.WORD - else: - raise OverflowError("value too big: ${:x}".format(value)) - else: - faultreason = check_value_in_range(datatype, "", 1, value) - if faultreason: - raise OverflowError(faultreason) - super().__init__(datatype, sourceref, name, True) - self.value = value - elif value is None: - if not name: - raise ValueError("when integer value is not given, the name symbol should be speicified") - super().__init__(datatype, sourceref, name, True) - self.value = None - else: - raise TypeError("invalid data type") - - def __hash__(self): - return hash((self.datatype, self.value, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, IntegerValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) - - def __str__(self): - return "".format(self.value, self.name) - - def negative(self) -> 'IntegerValue': - return IntegerValue(-self.value, self.sourceref, datatype=self.datatype, name=self.name) - - -class FloatValue(Value): - def __init__(self, value: float, sourceref: SourceRef, name: str = None) -> None: - if type(value) is float: - super().__init__(DataType.FLOAT, sourceref, name, True) - self.value = value - else: - raise TypeError("invalid data type") - - def __hash__(self): - return hash((self.datatype, self.value, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, FloatValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) - - def __str__(self): - return "".format(self.value, self.name) - - def negative(self) -> 'FloatValue': - return FloatValue(-self.value, self.sourceref, name=self.name) - - -class StringValue(Value): - def __init__(self, value: str, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: - super().__init__(DataType.STRING, sourceref, name, constant) - self.value = value - - def __hash__(self): - return hash((self.datatype, self.value, self.name, self.constant)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, StringValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.value, other.name, other.constant) == (self.datatype, self.value, self.name, self.constant) - - def __str__(self): - return "".format(self.value, self.name, self.constant) - - -class RegisterValue(Value): - def __init__(self, register: str, datatype: DataType, sourceref: SourceRef, name: str = None) -> None: - assert datatype in (DataType.BYTE, DataType.WORD) - assert register in REGISTER_SYMBOLS - super().__init__(datatype, sourceref, name, False) - self.register = register - - def __hash__(self): - return hash((self.datatype, self.register, self.name)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, RegisterValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.register, other.name) == (self.datatype, self.register, self.name) - - def __str__(self): - return "".format(self.register, self.datatype, self.name) - - def assignable_from(self, other: Value) -> Tuple[bool, str]: - if isinstance(other, IndirectValue): - if self.datatype == DataType.BYTE: - if other.datatype == DataType.BYTE: - return True, "" - return False, "(unsigned) byte required" - if self.datatype == DataType.WORD: - if other.datatype in (DataType.BYTE, DataType.WORD): - return True, "" - return False, "(unsigned) byte required" - return False, "incompatible indirect value for register assignment" - if self.register in ("SC", "SI"): - if isinstance(other, IntegerValue) and other.value in (0, 1): - return True, "" - return False, "can only assign an integer constant value of 0 or 1 to SC and SI" - if self.constant: - return False, "cannot assign to a constant" - if isinstance(other, RegisterValue): - if other.register in {"SI", "SC", "SZ"}: - return False, "cannot explicitly assign from a status bit register alias" - if len(self.register) < len(other.register): - return False, "register size mismatch" - if isinstance(other, StringValue) and self.register in REGISTER_BYTES | REGISTER_SBITS: - return False, "string address requires 16 bits combined register" - if isinstance(other, IntegerValue): - if other.value is not None: - range_error = check_value_in_range(self.datatype, self.register, 1, other.value) - if range_error: - return False, range_error - return True, "" - if self.datatype == DataType.WORD: - return True, "" - return False, "cannot assign address to single register" - if isinstance(other, FloatValue): - range_error = check_value_in_range(self.datatype, self.register, 1, other.value) - if range_error: - return False, range_error - return True, "" - if self.datatype == DataType.BYTE: - if other.datatype != DataType.BYTE: - return False, "(unsigned) byte required" - return True, "" - if self.datatype == DataType.WORD: - if other.datatype in (DataType.BYTE, DataType.WORD) or other.datatype in STRING_DATATYPES: - return True, "" - return False, "(unsigned) byte, word or string required" - return False, "incompatible value for register assignment" - - -class MemMappedValue(Value): - def __init__(self, address: Optional[int], datatype: DataType, length: int, - sourceref: SourceRef, name: str = None, constant: bool = False) -> None: - super().__init__(datatype, sourceref, name, constant) - self.address = address - self.length = length - assert address is None or type(address) is int - - def __hash__(self): - return hash((self.datatype, self.address, self.length, self.name, self.constant)) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, MemMappedValue): - return NotImplemented - elif self is other: - return True - else: - return (other.datatype, other.address, other.length, other.name, other.constant) == \ - (self.datatype, self.address, self.length, self.name, self.constant) - - def __str__(self): - addr = "" if self.address is None else "${:04x}".format(self.address) - return "" \ - .format(addr, self.datatype, self.length, self.name, self.constant) - - def assignable_from(self, other: Value) -> Tuple[bool, str]: - if self.constant: - return False, "cannot assign to a constant" - if isinstance(other, IndirectValue): - if self.datatype == other.datatype: - return True, "" - return False, "data type of value and target are not the same" - if self.datatype == DataType.BYTE: - if isinstance(other, (IntegerValue, RegisterValue, MemMappedValue)): - if other.datatype == DataType.BYTE: - return True, "" - return False, "(unsigned) byte required" - elif isinstance(other, FloatValue): - range_error = check_value_in_range(self.datatype, "", 1, other.value) - if range_error: - return False, range_error - return True, "" - else: - return False, "(unsigned) byte required" - elif self.datatype in (DataType.WORD, DataType.FLOAT): - if isinstance(other, (IntegerValue, FloatValue)): - range_error = check_value_in_range(self.datatype, "", 1, other.value) - if range_error: - return False, range_error - return True, "" - elif isinstance(other, (RegisterValue, MemMappedValue)): - if other.datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): - return True, "" - else: - return False, "byte or word or float required" - elif isinstance(other, StringValue): - if self.datatype == DataType.WORD: - return True, "" - return False, "string address requires 16 bits (a word)" - if self.datatype == DataType.BYTE: - return False, "(unsigned) byte required" - if self.datatype == DataType.WORD: - return False, "(unsigned) word required" - return False, "incompatible value for assignment" - - -class Comment(_AstNode): - def __init__(self, text: str, sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.text = text - - -class Label(_AstNode): - def __init__(self, name: str, sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.name = name - - -class AssignmentStmt(_AstNode): - def __init__(self, leftvalues: List[Value], right: Value, sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.leftvalues = leftvalues - self.right = right - - def __str__(self): - return "".format(str(self.right), ",".join(str(lv) for lv in self.leftvalues)) - - _immediate_string_vars = {} # type: Dict[str, Tuple[str, str]] - - def desugar_immediate_string(self, containing_block: Block) -> None: - if self.right.name or not isinstance(self.right, StringValue): - return - if self.right.value in self._immediate_string_vars: - blockname, stringvar_name = self._immediate_string_vars[self.right.value] - if blockname: - self.right.name = blockname + '.' + stringvar_name - else: - self.right.name = stringvar_name - else: - stringvar_name = "il65_str_{:d}".format(id(self)) - value = self.right.value - containing_block.symbols.define_constant(stringvar_name, self.sourceref, DataType.STRING, value=value) - self.right.name = stringvar_name - self._immediate_string_vars[self.right.value] = (containing_block.name, stringvar_name) - - def remove_identity_lvalues(self) -> None: - for lv in self.leftvalues: - if lv == self.right: - print("{}: removed identity assignment".format(self.sourceref)) - remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right] - self.leftvalues = remaining_leftvalues - - def is_identity(self) -> bool: - return all(lv == self.right for lv in self.leftvalues) - - -class AugmentedAssignmentStmt(AssignmentStmt): - SUPPORTED_OPERATORS = {"+=", "-=", "&=", "|=", "^=", ">>=", "<<="} - - # full set: {"+=", "-=", "*=", "/=", "%=", "//=", "**=", "&=", "|=", "^=", ">>=", "<<="} - - def __init__(self, left: Value, operator: str, right: Value, sourceref: SourceRef) -> None: - assert operator in self.SUPPORTED_OPERATORS - super().__init__([left], right, sourceref) - self.operator = operator - - def __str__(self): - return "".format(str(self.leftvalues[0]), self.operator, str(self.right)) - - -class ReturnStmt(_AstNode): - def __init__(self, sourceref: SourceRef, a: Optional[Value] = None, - x: Optional[Value] = None, - y: Optional[Value] = None) -> None: - super().__init__(sourceref) - self.a = a - self.x = x - self.y = y - - -class InplaceIncrStmt(_AstNode): - def __init__(self, what: Value, value: Union[IntegerValue, FloatValue], sourceref: SourceRef) -> None: - super().__init__(sourceref) - assert value.constant - assert (value.value is None and value.name) or value.value > 0 - self.what = what - self.value = value - - -class InplaceDecrStmt(_AstNode): - def __init__(self, what: Value, value: Union[IntegerValue, FloatValue], sourceref: SourceRef) -> None: - super().__init__(sourceref) - assert value.constant - assert (value.value is None and value.name) or value.value > 0 - self.what = what - self.value = value - - -class IfCondition(_AstNode): - SWAPPED_OPERATOR = {"==": "==", - "!=": "!=", - "<=": ">=", - ">=": "<=", - "<": ">", - ">": "<"} - IF_STATUSES = {"cc", "cs", "vc", "vs", "eq", "ne", "true", "not", "zero", "pos", "neg", "lt", "gt", "le", "ge"} - - def __init__(self, ifstatus: str, leftvalue: Optional[Value], - operator: str, rightvalue: Optional[Value], sourceref: SourceRef) -> None: - assert ifstatus in self.IF_STATUSES - assert operator in (None, "") or operator in self.SWAPPED_OPERATOR - if operator: - assert ifstatus in ("true", "not", "zero") - super().__init__(sourceref) - self.ifstatus = ifstatus - self.lvalue = leftvalue - self.comparison_op = operator - self.rvalue = rightvalue - - def __str__(self): - return "".format(self.ifstatus, self.lvalue, self.comparison_op, self.rvalue) - - def make_if_true(self) -> bool: - # makes a condition of the form if_not a < b into: if a > b (gets rid of the not) - # returns whether the change was made or not - if self.ifstatus == "not" and self.comparison_op: - self.ifstatus = "true" - self.comparison_op = self.SWAPPED_OPERATOR[self.comparison_op] - return True - return False - - def swap(self) -> Tuple[Value, str, Value]: - self.lvalue, self.comparison_op, self.rvalue = self.rvalue, self.SWAPPED_OPERATOR[self.comparison_op], self.lvalue - return self.lvalue, self.comparison_op, self.rvalue - - -class CallStmt(_AstNode): - def __init__(self, sourceref: SourceRef, target: Optional[Value] = None, *, - address: Optional[int] = None, arguments: List[Tuple[str, Any]] = None, - outputs: List[Tuple[str, Value]] = None, is_goto: bool = False, - preserve_regs: Set[str] = None, condition: IfCondition = None) -> None: - if not is_goto: - assert condition is None - super().__init__(sourceref) - self.target = target - self.address = address - self.arguments = arguments - self.outputvars = outputs - self.is_goto = is_goto - self.condition = condition - self.preserve_regs = preserve_regs - self.desugared_call_arguments = [] # type: List[AssignmentStmt] - self.desugared_output_assignments = [] # type: List[AssignmentStmt] - - -class InlineAsm(_AstNode): - def __init__(self, asmlines: List[str], sourceref: SourceRef) -> None: - super().__init__(sourceref) - self.asmlines = asmlines - - -class BreakpointStmt(_AstNode): - def __init__(self, sourceref: SourceRef) -> None: - super().__init__(sourceref) diff --git a/il65/codegen.py b/il65/codegen.py index 76f660e34..ed542775d 100644 --- a/il65/codegen.py +++ b/il65/codegen.py @@ -8,16 +8,13 @@ License: GNU GPL 3.0, see LICENSE import io import re -import math import datetime import subprocess import contextlib from functools import partial -from typing import TextIO, Set, Union, List, Callable +from typing import TextIO, Callable from .parse import ProgramFormat, ParseResult, Parser -from .astdefs import * -from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \ - STRING_DATATYPES, REGISTER_WORDS, REGISTER_BYTES, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE +from .symbols import * class CodeError(Exception): @@ -353,7 +350,7 @@ class CodeGenerator: self.p("{:s}\n\t\t.ptext {:s}".format(vardef.name, self.output_string(str(vardef.value), True))) self.p(".enc 'none'") - def generate_statement(self, stmt: _AstNode) -> None: + def generate_statement(self, stmt: AstNode) -> None: if isinstance(stmt, ReturnStmt): if stmt.a: if isinstance(stmt.a, IntegerValue): diff --git a/il65/main.py b/il65/main.py index 4f9da96b3..f1a8c5c53 100644 --- a/il65/main.py +++ b/il65/main.py @@ -12,7 +12,8 @@ import time import os import argparse import subprocess -from .parse import Parser, Optimizer +from .parse import Parser +from .optimize import Optimizer from .preprocess import PreprocessingParser from .codegen import CodeGenerator, Assembler64Tass diff --git a/il65/optimize.py b/il65/optimize.py new file mode 100644 index 000000000..7802ccaa5 --- /dev/null +++ b/il65/optimize.py @@ -0,0 +1,155 @@ +""" +Programming Language for 6502/6510 microprocessors +This is the code to optimize the parse tree. + +Written by Irmen de Jong (irmen@razorvine.net) +License: GNU GPL 3.0, see LICENSE +""" + +from typing import List +from .parse import ParseResult +from .symbols import Block, AugmentedAssignmentStmt, IntegerValue, FloatValue, AssignmentStmt, CallStmt, \ + Value, MemMappedValue, RegisterValue, AstNode + + +class Optimizer: + def __init__(self, parseresult: ParseResult) -> None: + self.parsed = parseresult + + def optimize(self) -> ParseResult: + print("\noptimizing parse tree") + for block in self.parsed.all_blocks(): + self.remove_augmentedassign_incrdecr_nops(block) + self.remove_identity_assigns(block) + self.combine_assignments_into_multi(block) + self.optimize_multiassigns(block) + self.remove_unused_subroutines(block) + self.optimize_compare_with_zero(block) + return self.parsed + + def remove_augmentedassign_incrdecr_nops(self, block: Block) -> None: + have_removed_stmts = False + for index, stmt in enumerate(list(block.statements)): + if isinstance(stmt, AugmentedAssignmentStmt): + if isinstance(stmt.right, (IntegerValue, FloatValue)): + if stmt.right.value == 0 and stmt.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="): + print("{}: removed statement that has no effect".format(stmt.sourceref)) + have_removed_stmts = True + block.statements[index] = None + if stmt.right.value >= 8 and stmt.operator in ("<<=", ">>="): + print("{}: shifting that many times always results in zero".format(stmt.sourceref)) + new_stmt = AssignmentStmt(stmt.leftvalues, IntegerValue(0, stmt.sourceref), stmt.sourceref) + block.statements[index] = new_stmt + if have_removed_stmts: + # remove the Nones + block.statements = [s for s in block.statements if s is not None] + + def optimize_compare_with_zero(self, block: Block) -> None: + # a conditional goto that compares a value to zero will be simplified + # the comparison operator and rvalue (0) will be removed and the if-status changed accordingly + for stmt in block.statements: + if isinstance(stmt, CallStmt): + cond = stmt.condition + if cond and isinstance(cond.rvalue, (IntegerValue, FloatValue)) and cond.rvalue.value == 0: + simplified = False + if cond.ifstatus in ("true", "ne"): + if cond.comparison_op == "==": + # if_true something == 0 -> if_not something + cond.ifstatus = "not" + cond.comparison_op, cond.rvalue = "", None + simplified = True + elif cond.comparison_op == "!=": + # if_true something != 0 -> if_true something + cond.comparison_op, cond.rvalue = "", None + simplified = True + elif cond.ifstatus in ("not", "eq"): + if cond.comparison_op == "==": + # if_not something == 0 -> if_true something + cond.ifstatus = "true" + cond.comparison_op, cond.rvalue = "", None + simplified = True + elif cond.comparison_op == "!=": + # if_not something != 0 -> if_not something + cond.comparison_op, cond.rvalue = "", None + simplified = True + if simplified: + print("{}: simplified comparison with zero".format(stmt.sourceref)) + + def combine_assignments_into_multi(self, block: Block) -> None: + # fold multiple consecutive assignments with the same rvalue into one multi-assignment + statements = [] # type: List[AstNode] + multi_assign_statement = None + for stmt in block.statements: + if isinstance(stmt, AssignmentStmt) and not isinstance(stmt, AugmentedAssignmentStmt): + if multi_assign_statement and multi_assign_statement.right == stmt.right: + multi_assign_statement.leftvalues.extend(stmt.leftvalues) + print("{}: joined with previous line into multi-assign statement".format(stmt.sourceref)) + else: + if multi_assign_statement: + statements.append(multi_assign_statement) + multi_assign_statement = stmt + else: + if multi_assign_statement: + statements.append(multi_assign_statement) + multi_assign_statement = None + statements.append(stmt) + if multi_assign_statement: + statements.append(multi_assign_statement) + block.statements = statements + + def optimize_multiassigns(self, block: Block) -> None: + # optimize multi-assign statements. + for stmt in block.statements: + if isinstance(stmt, AssignmentStmt) and len(stmt.leftvalues) > 1: + # remove duplicates + lvalues = list(set(stmt.leftvalues)) + if len(lvalues) != len(stmt.leftvalues): + print("{}: removed duplicate assignment targets".format(stmt.sourceref)) + # change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) + stmt.leftvalues = list(sorted(lvalues, key=_value_sortkey)) + + def remove_identity_assigns(self, block: Block) -> None: + have_removed_stmts = False + for index, stmt in enumerate(list(block.statements)): + if isinstance(stmt, AssignmentStmt): + stmt.remove_identity_lvalues() + if not stmt.leftvalues: + print("{}: removed identity assignment statement".format(stmt.sourceref)) + have_removed_stmts = True + block.statements[index] = None + if have_removed_stmts: + # remove the Nones + block.statements = [s for s in block.statements if s is not None] + + def remove_unused_subroutines(self, block: Block) -> None: + # some symbols are used by the emitted assembly code from the code generator, + # and should never be removed or the assembler will fail + never_remove = {"c64.FREADUY", "c64.FTOMEMXY", "c64.FADD", "c64.FSUB", + "c64flt.GIVUAYF", "c64flt.copy_mflt", "c64flt.float_add_one", "c64flt.float_sub_one", + "c64flt.float_add_SW1_to_XY", "c64flt.float_sub_SW1_from_XY"} + discarded = [] + for sub in list(block.symbols.iter_subroutines()): + usages = self.parsed.subroutine_usage[(sub.blockname, sub.name)] + if not usages and sub.blockname + '.' + sub.name not in never_remove: + block.symbols.discard_sub(sub.name) + discarded.append(sub.name) + if discarded: + print("{}: discarded {:d} unused subroutines from block '{:s}'".format(block.sourceref, len(discarded), block.name)) + + +def _value_sortkey(value: Value) -> int: + if isinstance(value, RegisterValue): + num = 0 + for char in value.register: + num *= 100 + num += ord(char) + return num + elif isinstance(value, MemMappedValue): + if value.address is None: + return 99999999 + if value.address < 0x100: + return 10000 + value.address + else: + return 20000 + value.address + else: + return 99999999 diff --git a/il65/parse.py b/il65/parse.py index 9a0ad0bde..c1763f93b 100644 --- a/il65/parse.py +++ b/il65/parse.py @@ -6,21 +6,15 @@ Written by Irmen de Jong (irmen@razorvine.net) License: GNU GPL 3.0, see LICENSE """ -import math import re import os import sys import shutil -import enum from collections import defaultdict from typing import Set, List, Tuple, Optional, Dict, Union, Generator from .exprparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\ parse_expr_as_string, parse_arguments, parse_expr_as_comparison -from .astdefs import * -from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \ - Zeropage, char_to_bytevalue, \ - PrimitiveType, VariableDef, ConstantDef, SymbolError, STRING_DATATYPES, \ - REGISTER_SYMBOLS, REGISTER_WORDS, REGISTER_BYTES, REGISTER_SBITS, RESERVED_NAMES +from .symbols import * class ProgramFormat(enum.Enum): @@ -222,6 +216,8 @@ class Parser: if isinstance(stmt, InlineAsm): # check that the last asm line is a jmp or a rts for asmline in reversed(stmt.asmlines): + if asmline.strip().replace(' ', '').startswith(";returns"): + return if asmline.lstrip().startswith(';'): continue if " rts" in asmline or "\trts" in asmline or " jmp" in asmline or "\tjmp" in asmline: @@ -234,6 +230,7 @@ class Parser: self.print_warning("{:s} doesn't end with a return statement".format(message), block.sourceref) _immediate_floats = {} # type: Dict[float, Tuple[str, str]] + _immediate_string_vars = {} # type: Dict[str, Tuple[str, str]] def _parse_2(self) -> None: # parsing pass 2 (not done during preprocessing!) @@ -241,19 +238,35 @@ class Parser: self.sourceref.line = -1 self.sourceref.column = 0 - def desugar_immediate_strings(stmt: _AstNode, containing_block: Block) -> None: + def imm_string_to_var(stmt: AssignmentStmt, containing_block: Block) -> None: + if stmt.right.name or not isinstance(stmt.right, StringValue): + return + if stmt.right.value in self._immediate_string_vars: + blockname, stringvar_name = self._immediate_string_vars[stmt.right.value] + if blockname: + stmt.right.name = blockname + '.' + stringvar_name + else: + stmt.right.name = stringvar_name + else: + stringvar_name = "il65_str_{:d}".format(id(stmt)) + value = stmt.right.value + containing_block.symbols.define_constant(stringvar_name, stmt.sourceref, DataType.STRING, value=value) + stmt.right.name = stringvar_name + self._immediate_string_vars[stmt.right.value] = (containing_block.name, stringvar_name) + + def desugar_immediate_strings(stmt: AstNode, containing_block: Block) -> None: if isinstance(stmt, CallStmt): for s in stmt.desugared_call_arguments: self.sourceref = s.sourceref.copy() - s.desugar_immediate_string(containing_block) + imm_string_to_var(s, containing_block) for s in stmt.desugared_output_assignments: self.sourceref = s.sourceref.copy() - s.desugar_immediate_string(containing_block) + imm_string_to_var(s, containing_block) if isinstance(stmt, AssignmentStmt): self.sourceref = stmt.sourceref.copy() - stmt.desugar_immediate_string(containing_block) + imm_string_to_var(stmt, containing_block) - def desugar_immediate_floats(stmt: _AstNode, containing_block: Block) -> None: + def desugar_immediate_floats(stmt: AstNode, containing_block: Block) -> None: if isinstance(stmt, (InplaceIncrStmt, InplaceDecrStmt)): howmuch = stmt.value.value if howmuch is None: @@ -815,7 +828,7 @@ class Parser: datatype, length, matrix_dimensions = self.get_datatype(args[1]) return varname, datatype, length, matrix_dimensions, valuetext - def parse_statement(self, line: str) -> _AstNode: + def parse_statement(self, line: str) -> AstNode: match = re.fullmatch(r"(?Pif(_[a-z]+)?)\s+(?P.+)?goto\s+(?P[\S]+?)\s*(\((?P.*)\))?\s*", line) if match: # conditional goto @@ -1337,146 +1350,3 @@ class Parser: return False self.existing_imports.add(filename) return True - - -class Optimizer: - def __init__(self, parseresult: ParseResult) -> None: - self.parsed = parseresult - - def optimize(self) -> ParseResult: - print("\noptimizing parse tree") - for block in self.parsed.all_blocks(): - self.remove_augmentedassign_incrdecr_nops(block) - self.remove_identity_assigns(block) - self.combine_assignments_into_multi(block) - self.optimize_multiassigns(block) - self.remove_unused_subroutines(block) - self.optimize_compare_with_zero(block) - return self.parsed - - def remove_augmentedassign_incrdecr_nops(self, block: Block) -> None: - have_removed_stmts = False - for index, stmt in enumerate(list(block.statements)): - if isinstance(stmt, AugmentedAssignmentStmt): - if isinstance(stmt.right, (IntegerValue, FloatValue)): - if stmt.right.value == 0 and stmt.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="): - print("{}: removed statement that has no effect".format(stmt.sourceref)) - have_removed_stmts = True - block.statements[index] = None - if stmt.right.value >= 8 and stmt.operator in ("<<=", ">>="): - print("{}: shifting that many times always results in zero".format(stmt.sourceref)) - new_stmt = AssignmentStmt(stmt.leftvalues, IntegerValue(0, stmt.sourceref), stmt.sourceref) - block.statements[index] = new_stmt - if have_removed_stmts: - # remove the Nones - block.statements = [s for s in block.statements if s is not None] - - def optimize_compare_with_zero(self, block: Block) -> None: - # a conditional goto that compares a value to zero will be simplified - # the comparison operator and rvalue (0) will be removed and the if-status changed accordingly - for stmt in block.statements: - if isinstance(stmt, CallStmt): - cond = stmt.condition - if cond and isinstance(cond.rvalue, (IntegerValue, FloatValue)) and cond.rvalue.value == 0: - simplified = False - if cond.ifstatus in ("true", "ne"): - if cond.comparison_op == "==": - # if_true something == 0 -> if_not something - cond.ifstatus = "not" - cond.comparison_op, cond.rvalue = "", None - simplified = True - elif cond.comparison_op == "!=": - # if_true something != 0 -> if_true something - cond.comparison_op, cond.rvalue = "", None - simplified = True - elif cond.ifstatus in ("not", "eq"): - if cond.comparison_op == "==": - # if_not something == 0 -> if_true something - cond.ifstatus = "true" - cond.comparison_op, cond.rvalue = "", None - simplified = True - elif cond.comparison_op == "!=": - # if_not something != 0 -> if_not something - cond.comparison_op, cond.rvalue = "", None - simplified = True - if simplified: - print("{}: simplified comparison with zero".format(stmt.sourceref)) - - def combine_assignments_into_multi(self, block: Block) -> None: - # fold multiple consecutive assignments with the same rvalue into one multi-assignment - statements = [] # type: List[_AstNode] - multi_assign_statement = None - for stmt in block.statements: - if isinstance(stmt, AssignmentStmt) and not isinstance(stmt, AugmentedAssignmentStmt): - if multi_assign_statement and multi_assign_statement.right == stmt.right: - multi_assign_statement.leftvalues.extend(stmt.leftvalues) - print("{}: joined with previous line into multi-assign statement".format(stmt.sourceref)) - else: - if multi_assign_statement: - statements.append(multi_assign_statement) - multi_assign_statement = stmt - else: - if multi_assign_statement: - statements.append(multi_assign_statement) - multi_assign_statement = None - statements.append(stmt) - if multi_assign_statement: - statements.append(multi_assign_statement) - block.statements = statements - - def optimize_multiassigns(self, block: Block) -> None: - # optimize multi-assign statements. - for stmt in block.statements: - if isinstance(stmt, AssignmentStmt) and len(stmt.leftvalues) > 1: - # remove duplicates - lvalues = list(set(stmt.leftvalues)) - if len(lvalues) != len(stmt.leftvalues): - print("{}: removed duplicate assignment targets".format(stmt.sourceref)) - # change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) - stmt.leftvalues = list(sorted(lvalues, key=_value_sortkey)) - - def remove_identity_assigns(self, block: Block) -> None: - have_removed_stmts = False - for index, stmt in enumerate(list(block.statements)): - if isinstance(stmt, AssignmentStmt): - stmt.remove_identity_lvalues() - if not stmt.leftvalues: - print("{}: removed identity assignment statement".format(stmt.sourceref)) - have_removed_stmts = True - block.statements[index] = None - if have_removed_stmts: - # remove the Nones - block.statements = [s for s in block.statements if s is not None] - - def remove_unused_subroutines(self, block: Block) -> None: - # some symbols are used by the emitted assembly code from the code generator, - # and should never be removed or the assembler will fail - never_remove = {"c64.FREADUY", "c64.FTOMEMXY", "c64.FADD", "c64.FSUB", - "c64flt.GIVUAYF", "c64flt.copy_mflt", "c64flt.float_add_one", "c64flt.float_sub_one", - "c64flt.float_add_SW1_to_XY", "c64flt.float_sub_SW1_from_XY"} - discarded = [] - for sub in list(block.symbols.iter_subroutines()): - usages = self.parsed.subroutine_usage[(sub.blockname, sub.name)] - if not usages and sub.blockname + '.' + sub.name not in never_remove: - block.symbols.discard_sub(sub.name) - discarded.append(sub.name) - if discarded: - print("{}: discarded {:d} unused subroutines from block '{:s}'".format(block.sourceref, len(discarded), block.name)) - - -def _value_sortkey(value: Value) -> int: - if isinstance(value, RegisterValue): - num = 0 - for char in value.register: - num *= 100 - num += ord(char) - return num - elif isinstance(value, MemMappedValue): - if value.address is None: - return 99999999 - if value.address < 0x100: - return 10000 + value.address - else: - return 20000 + value.address - else: - return 99999999 diff --git a/il65/preprocess.py b/il65/preprocess.py index 4b3ccc50f..02bf97252 100644 --- a/il65/preprocess.py +++ b/il65/preprocess.py @@ -8,8 +8,7 @@ License: GNU GPL 3.0, see LICENSE from typing import List, Tuple, Set from .parse import Parser, ParseResult, SymbolTable, SymbolDefinition -from .symbols import SourceRef -from .astdefs import _AstNode, InlineAsm +from .symbols import SourceRef, AstNode, InlineAsm class PreprocessingParser(Parser): @@ -45,7 +44,7 @@ class PreprocessingParser(Parser): def parse_asminclude(self, line: str) -> InlineAsm: return InlineAsm([], self.sourceref) - def parse_statement(self, line: str) -> _AstNode: + def parse_statement(self, line: str) -> AstNode: return None def parse_var_def(self, line: str) -> None: diff --git a/il65/symbols.py b/il65/symbols.py index d32b29b8c..af9a78898 100644 --- a/il65/symbols.py +++ b/il65/symbols.py @@ -11,7 +11,8 @@ import math import enum import builtins from functools import total_ordering -from typing import Optional, Set, Union, Tuple, Dict, Iterable, Sequence, Any, List +from typing import Optional, Set, Union, Tuple, Dict, Iterable, Sequence, Any, List, Generator + PrimitiveType = Union[int, float, str] @@ -170,10 +171,10 @@ class ConstantDef(SymbolDefinition): class SubroutineDef(SymbolDefinition): def __init__(self, blockname: str, name: str, sourceref: SourceRef, parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str], - address: Optional[int]=None, sub_block: Any=None) -> None: + address: Optional[int]=None, sub_block: 'Block'=None) -> None: super().__init__(blockname, name, sourceref, False) self.address = address - self.sub_block = sub_block # this is a ParseResult.Block + self.sub_block = sub_block self.parameters = parameters self.clobbered_registers = set() # type: Set[str] self.return_registers = [] # type: List[str] # ordered! @@ -269,7 +270,7 @@ class Zeropage: class SymbolTable: - def __init__(self, name: str, parent: Optional['SymbolTable'], owning_block: Any) -> None: + def __init__(self, name: str, parent: Optional['SymbolTable'], owning_block: 'Block') -> None: self.name = name self.symbols = {} # type: Dict[str, Union[SymbolDefinition, SymbolTable]] self.parent = parent @@ -323,7 +324,7 @@ class SymbolTable: if isinstance(scope, SymbolTable): return scope.lookup(nameparts[-1]) elif isinstance(scope, SubroutineDef): - return scope.sub_block.symbols.lookup_with_ppsymbols(nameparts[-1]) + return scope.sub_block.symbols.lookup(nameparts[-1]) else: raise SymbolError("invalid block name '{:s}' in dotted name".format(namepart)) @@ -337,7 +338,7 @@ class SymbolTable: raise SymbolError("can only take address of memory mapped variables") return symbol.address - def as_eval_dict(self, ppsymbols: 'SymbolTable') -> Dict[str, Any]: + def as_eval_dict(self, ppsymbols: 'SymbolTable') -> Dict[str, Any]: # @todo type # return a dictionary suitable to be passed as locals or globals to eval() if self.eval_dict is None: d = EvalSymbolDict(self, ppsymbols) @@ -430,7 +431,7 @@ class SymbolTable: def define_sub(self, name: str, sourceref: SourceRef, parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str], - address: Optional[int], sub_block: Any) -> None: + address: Optional[int], sub_block: 'Block') -> None: self.check_identifier_valid(name, sourceref) self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block) @@ -691,3 +692,481 @@ ascii_to_petscii_trans = str.maketrans({ '▄': 162, # lower half '▒': 230, # raster }) + + +class AstNode: + def __init__(self, sourceref: SourceRef) -> None: + self.sourceref = sourceref.copy() + + @property + def lineref(self) -> str: + return "src l. " + str(self.sourceref.line) + + +class Block(AstNode): + _unnamed_block_labels = {} # type: Dict[Block, str] + + def __init__(self, name: str, sourceref: SourceRef, parent_scope: SymbolTable, preserve_registers: bool=False) -> None: + super().__init__(sourceref) + self.address = 0 + self.name = name + self.statements = [] # type: List[AstNode] + self.symbols = SymbolTable(name, parent_scope, self) + self.preserve_registers = preserve_registers + + @property + def ignore(self) -> bool: + return not self.name and not self.address + + @property + def label_names(self) -> Set[str]: + return {symbol.name for symbol in self.symbols.iter_labels()} + + @property + def label(self) -> str: + if self.name: + return self.name + if self in self._unnamed_block_labels: + return self._unnamed_block_labels[self] + label = "il65_block_{:d}".format(len(self._unnamed_block_labels)) + self._unnamed_block_labels[self] = label + return label + + def lookup(self, dottedname: str) -> Tuple[Optional['Block'], Optional[Union[SymbolDefinition, SymbolTable]]]: + # Searches a name in the current block or globally, if the name is scoped (=contains a '.'). + # Does NOT utilize a symbol table from a preprocessing parse phase, only looks in the current. + try: + scope, result = self.symbols.lookup(dottedname) + return scope.owning_block, result + except (SymbolError, LookupError): + return None, None + + def all_statements(self) -> Generator[Tuple['Block', Optional[SubroutineDef], AstNode], None, None]: + for stmt in self.statements: + yield self, None, stmt + for sub in self.symbols.iter_subroutines(True): + for stmt in sub.sub_block.statements: + yield sub.sub_block, sub, stmt + + +class Value(AstNode): + def __init__(self, datatype: DataType, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(sourceref) + self.datatype = datatype + self.name = name + self.constant = constant + + def assignable_from(self, other: 'Value') -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + return False, "incompatible value for assignment" + + +class IndirectValue(Value): + # only constant integers, memmapped and register values are wrapped in this. + def __init__(self, value: Value, type_modifier: DataType, sourceref: SourceRef) -> None: + assert type_modifier + super().__init__(type_modifier, sourceref, value.name, False) + self.value = value + + def __str__(self): + return "".format(self.value, self.datatype, self.name) + + def __hash__(self): + return hash((self.datatype, self.name, self.value)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, IndirectValue): + return NotImplemented + elif self is other: + return True + else: + vvo = getattr(other.value, "value", getattr(other.value, "address", None)) + vvs = getattr(self.value, "value", getattr(self.value, "address", None)) + return (other.datatype, other.name, other.value.name, other.value.datatype, other.value.constant, vvo) == \ + (self.datatype, self.name, self.value.name, self.value.datatype, self.value.constant, vvs) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + if self.datatype == DataType.BYTE: + if other.datatype == DataType.BYTE: + return True, "" + if self.datatype == DataType.WORD: + if other.datatype in {DataType.BYTE, DataType.WORD} | STRING_DATATYPES: + return True, "" + if self.datatype == DataType.FLOAT: + if other.datatype in {DataType.BYTE, DataType.WORD, DataType.FLOAT}: + return True, "" + if isinstance(other, (IntegerValue, FloatValue, StringValue)): + rangefault = check_value_in_range(self.datatype, "", 1, other.value) + if rangefault: + return False, rangefault + return True, "" + return False, "incompatible value for indirect assignment (need byte, word, float or string)" + + +class IntegerValue(Value): + def __init__(self, value: Optional[int], sourceref: SourceRef, *, datatype: DataType = None, name: str = None) -> None: + if type(value) is int: + if datatype is None: + if 0 <= value < 0x100: + datatype = DataType.BYTE + elif value < 0x10000: + datatype = DataType.WORD + else: + raise OverflowError("value too big: ${:x}".format(value)) + else: + faultreason = check_value_in_range(datatype, "", 1, value) + if faultreason: + raise OverflowError(faultreason) + super().__init__(datatype, sourceref, name, True) + self.value = value + elif value is None: + if not name: + raise ValueError("when integer value is not given, the name symbol should be speicified") + super().__init__(datatype, sourceref, name, True) + self.value = None + else: + raise TypeError("invalid data type") + + def __hash__(self): + return hash((self.datatype, self.value, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, IntegerValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) + + def __str__(self): + return "".format(self.value, self.name) + + def negative(self) -> 'IntegerValue': + return IntegerValue(-self.value, self.sourceref, datatype=self.datatype, name=self.name) + + +class FloatValue(Value): + def __init__(self, value: float, sourceref: SourceRef, name: str = None) -> None: + if type(value) is float: + super().__init__(DataType.FLOAT, sourceref, name, True) + self.value = value + else: + raise TypeError("invalid data type") + + def __hash__(self): + return hash((self.datatype, self.value, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FloatValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name) == (self.datatype, self.value, self.name) + + def __str__(self): + return "".format(self.value, self.name) + + def negative(self) -> 'FloatValue': + return FloatValue(-self.value, self.sourceref, name=self.name) + + +class StringValue(Value): + def __init__(self, value: str, sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(DataType.STRING, sourceref, name, constant) + self.value = value + + def __hash__(self): + return hash((self.datatype, self.value, self.name, self.constant)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, StringValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.value, other.name, other.constant) == (self.datatype, self.value, self.name, self.constant) + + def __str__(self): + return "".format(self.value, self.name, self.constant) + + +class RegisterValue(Value): + def __init__(self, register: str, datatype: DataType, sourceref: SourceRef, name: str = None) -> None: + assert datatype in (DataType.BYTE, DataType.WORD) + assert register in REGISTER_SYMBOLS + super().__init__(datatype, sourceref, name, False) + self.register = register + + def __hash__(self): + return hash((self.datatype, self.register, self.name)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RegisterValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.register, other.name) == (self.datatype, self.register, self.name) + + def __str__(self): + return "".format(self.register, self.datatype, self.name) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if isinstance(other, IndirectValue): + if self.datatype == DataType.BYTE: + if other.datatype == DataType.BYTE: + return True, "" + return False, "(unsigned) byte required" + if self.datatype == DataType.WORD: + if other.datatype in (DataType.BYTE, DataType.WORD): + return True, "" + return False, "(unsigned) byte required" + return False, "incompatible indirect value for register assignment" + if self.register in ("SC", "SI"): + if isinstance(other, IntegerValue) and other.value in (0, 1): + return True, "" + return False, "can only assign an integer constant value of 0 or 1 to SC and SI" + if self.constant: + return False, "cannot assign to a constant" + if isinstance(other, RegisterValue): + if other.register in {"SI", "SC", "SZ"}: + return False, "cannot explicitly assign from a status bit register alias" + if len(self.register) < len(other.register): + return False, "register size mismatch" + if isinstance(other, StringValue) and self.register in REGISTER_BYTES | REGISTER_SBITS: + return False, "string address requires 16 bits combined register" + if isinstance(other, IntegerValue): + if other.value is not None: + range_error = check_value_in_range(self.datatype, self.register, 1, other.value) + if range_error: + return False, range_error + return True, "" + if self.datatype == DataType.WORD: + return True, "" + return False, "cannot assign address to single register" + if isinstance(other, FloatValue): + range_error = check_value_in_range(self.datatype, self.register, 1, other.value) + if range_error: + return False, range_error + return True, "" + if self.datatype == DataType.BYTE: + if other.datatype != DataType.BYTE: + return False, "(unsigned) byte required" + return True, "" + if self.datatype == DataType.WORD: + if other.datatype in (DataType.BYTE, DataType.WORD) or other.datatype in STRING_DATATYPES: + return True, "" + return False, "(unsigned) byte, word or string required" + return False, "incompatible value for register assignment" + + +class MemMappedValue(Value): + def __init__(self, address: Optional[int], datatype: DataType, length: int, + sourceref: SourceRef, name: str = None, constant: bool = False) -> None: + super().__init__(datatype, sourceref, name, constant) + self.address = address + self.length = length + assert address is None or type(address) is int + + def __hash__(self): + return hash((self.datatype, self.address, self.length, self.name, self.constant)) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, MemMappedValue): + return NotImplemented + elif self is other: + return True + else: + return (other.datatype, other.address, other.length, other.name, other.constant) == \ + (self.datatype, self.address, self.length, self.name, self.constant) + + def __str__(self): + addr = "" if self.address is None else "${:04x}".format(self.address) + return "" \ + .format(addr, self.datatype, self.length, self.name, self.constant) + + def assignable_from(self, other: Value) -> Tuple[bool, str]: + if self.constant: + return False, "cannot assign to a constant" + if isinstance(other, IndirectValue): + if self.datatype == other.datatype: + return True, "" + return False, "data type of value and target are not the same" + if self.datatype == DataType.BYTE: + if isinstance(other, (IntegerValue, RegisterValue, MemMappedValue)): + if other.datatype == DataType.BYTE: + return True, "" + return False, "(unsigned) byte required" + elif isinstance(other, FloatValue): + range_error = check_value_in_range(self.datatype, "", 1, other.value) + if range_error: + return False, range_error + return True, "" + else: + return False, "(unsigned) byte required" + elif self.datatype in (DataType.WORD, DataType.FLOAT): + if isinstance(other, (IntegerValue, FloatValue)): + range_error = check_value_in_range(self.datatype, "", 1, other.value) + if range_error: + return False, range_error + return True, "" + elif isinstance(other, (RegisterValue, MemMappedValue)): + if other.datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): + return True, "" + else: + return False, "byte or word or float required" + elif isinstance(other, StringValue): + if self.datatype == DataType.WORD: + return True, "" + return False, "string address requires 16 bits (a word)" + if self.datatype == DataType.BYTE: + return False, "(unsigned) byte required" + if self.datatype == DataType.WORD: + return False, "(unsigned) word required" + return False, "incompatible value for assignment" + + +class Comment(AstNode): + def __init__(self, text: str, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.text = text + + +class Label(AstNode): + def __init__(self, name: str, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.name = name + + +class AssignmentStmt(AstNode): + def __init__(self, leftvalues: List[Value], right: Value, sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.leftvalues = leftvalues + self.right = right + + def __str__(self): + return "".format(str(self.right), ",".join(str(lv) for lv in self.leftvalues)) + + def remove_identity_lvalues(self) -> None: + for lv in self.leftvalues: + if lv == self.right: + print("{}: removed identity assignment".format(self.sourceref)) + remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right] + self.leftvalues = remaining_leftvalues + + def is_identity(self) -> bool: + return all(lv == self.right for lv in self.leftvalues) + + +class AugmentedAssignmentStmt(AssignmentStmt): + SUPPORTED_OPERATORS = {"+=", "-=", "&=", "|=", "^=", ">>=", "<<="} + + # full set: {"+=", "-=", "*=", "/=", "%=", "//=", "**=", "&=", "|=", "^=", ">>=", "<<="} + + def __init__(self, left: Value, operator: str, right: Value, sourceref: SourceRef) -> None: + assert operator in self.SUPPORTED_OPERATORS + super().__init__([left], right, sourceref) + self.operator = operator + + def __str__(self): + return "".format(str(self.leftvalues[0]), self.operator, str(self.right)) + + +class ReturnStmt(AstNode): + def __init__(self, sourceref: SourceRef, a: Optional[Value] = None, + x: Optional[Value] = None, + y: Optional[Value] = None) -> None: + super().__init__(sourceref) + self.a = a + self.x = x + self.y = y + + +class InplaceIncrStmt(AstNode): + def __init__(self, what: Value, value: Union[IntegerValue, FloatValue], sourceref: SourceRef) -> None: + super().__init__(sourceref) + assert value.constant + assert (value.value is None and value.name) or value.value > 0 + self.what = what + self.value = value + + +class InplaceDecrStmt(AstNode): + def __init__(self, what: Value, value: Union[IntegerValue, FloatValue], sourceref: SourceRef) -> None: + super().__init__(sourceref) + assert value.constant + assert (value.value is None and value.name) or value.value > 0 + self.what = what + self.value = value + + +class IfCondition(AstNode): + SWAPPED_OPERATOR = {"==": "==", + "!=": "!=", + "<=": ">=", + ">=": "<=", + "<": ">", + ">": "<"} + IF_STATUSES = {"cc", "cs", "vc", "vs", "eq", "ne", "true", "not", "zero", "pos", "neg", "lt", "gt", "le", "ge"} + + def __init__(self, ifstatus: str, leftvalue: Optional[Value], + operator: str, rightvalue: Optional[Value], sourceref: SourceRef) -> None: + assert ifstatus in self.IF_STATUSES + assert operator in (None, "") or operator in self.SWAPPED_OPERATOR + if operator: + assert ifstatus in ("true", "not", "zero") + super().__init__(sourceref) + self.ifstatus = ifstatus + self.lvalue = leftvalue + self.comparison_op = operator + self.rvalue = rightvalue + + def __str__(self): + return "".format(self.ifstatus, self.lvalue, self.comparison_op, self.rvalue) + + def make_if_true(self) -> bool: + # makes a condition of the form if_not a < b into: if a > b (gets rid of the not) + # returns whether the change was made or not + if self.ifstatus == "not" and self.comparison_op: + self.ifstatus = "true" + self.comparison_op = self.SWAPPED_OPERATOR[self.comparison_op] + return True + return False + + def swap(self) -> Tuple[Value, str, Value]: + self.lvalue, self.comparison_op, self.rvalue = self.rvalue, self.SWAPPED_OPERATOR[self.comparison_op], self.lvalue + return self.lvalue, self.comparison_op, self.rvalue + + +class CallStmt(AstNode): + def __init__(self, sourceref: SourceRef, target: Optional[Value] = None, *, + address: Optional[int] = None, arguments: List[Tuple[str, Any]] = None, + outputs: List[Tuple[str, Value]] = None, is_goto: bool = False, + preserve_regs: Set[str] = None, condition: IfCondition = None) -> None: + if not is_goto: + assert condition is None + super().__init__(sourceref) + self.target = target + self.address = address + self.arguments = arguments + self.outputvars = outputs + self.is_goto = is_goto + self.condition = condition + self.preserve_regs = preserve_regs + self.desugared_call_arguments = [] # type: List[AssignmentStmt] + self.desugared_output_assignments = [] # type: List[AssignmentStmt] + + +class InlineAsm(AstNode): + def __init__(self, asmlines: List[str], sourceref: SourceRef) -> None: + super().__init__(sourceref) + self.asmlines = asmlines + + +class BreakpointStmt(AstNode): + def __init__(self, sourceref: SourceRef) -> None: + super().__init__(sourceref) diff --git a/lib/mathlib.ill b/lib/mathlib.ill index 0aa648297..e594c9b34 100644 --- a/lib/mathlib.ill +++ b/lib/mathlib.ill @@ -173,4 +173,76 @@ remainder = SCRATCH_ZP1 } } + +sub randbyte () -> (A) { + ; ---- 8-bit pseudo random number generator into A + + %asm { + lda _seed + beq + + asl a + beq ++ ;if the input was $80, skip the EOR + bcc ++ ++ eor _magic ; #$1d ; could be self-modifying code to set new magic ++ sta _seed + rts + + _seed .byte $3a +_magic .byte $1d +_magiceors .byte $1d, $2b, $2d, $4d, $5f, $63, $65, $69 + .byte $71, $87, $8d, $a9, $c3, $cf, $e7, $f5 + + ;returns - this comment avoids compiler warning + } +} + +sub randword () -> (XY) { + ; ---- 16 bit pseudo random number generator into XY + + %asm { + lda _seed + beq _lowZero ; $0000 and $8000 are special values to test for + + ; Do a normal shift + asl _seed + lda _seed+1 + rol a + bcc _noEor + +_doEor ; high byte is in A + eor _magic+1 ; #>magic ; could be self-modifying code to set new magic + sta _seed+1 + lda _seed + eor _magic ; #>= var1