diff --git a/README.md b/README.md index f88f61380..a9760f60b 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ which aims to provide many conveniences over raw assembly code (even when using - breakpoints, that let the Vice emulator drop into the monitor if execution hits them - source code labels automatically loaded in Vice emulator so it can show them in disassembly - conditional gotos -- some code optimizations (such as not repeatedly loading the same value in a register) +- various code optimizations (code structure, logical and numerical expressions, ...) - @todo: loops - @todo: memory block operations diff --git a/il65/compile.py b/il65/compile.py index 9d2fcb618..3e07b2c86 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -11,10 +11,7 @@ import sys import linecache from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check import attr -from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \ - SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\ - SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \ - check_symbol_definition, UndefinedSymbolError, process_expression, AugAssignment +from .plyparse import * from .plylex import SourceRef, print_bold from .datatypes import DataType, VarType @@ -38,7 +35,7 @@ class PlyParser: self.check_all_symbolnames(module) self.create_multiassigns(module) self.check_and_merge_zeropages(module) - self.process_all_expressions(module) + self.simplify_some_assignments(module) if not self.imported_module: # the following shall only be done on the main module after all imports have been done: self.apply_directive_options(module) @@ -75,7 +72,14 @@ class PlyParser: # perform semantic analysis / checks on the syntactic parse tree we have so far # (note: symbol names have already been checked to exist when we start this) previous_stmt = None + encountered_blocks = set() # type: Set[Block] for node in module.all_nodes(): + if isinstance(node, Block): + parentname = (node.parent.name + ".") if node.parent else "" + blockname = parentname + node.name + if blockname in encountered_blocks: + raise ValueError("block names not unique:", blockname) + encountered_blocks.add(blockname) if isinstance(node, Scope): if node.nodes and isinstance(node.parent, (Block, Subroutine)): if isinstance(node.parent, Block) and node.parent.name != "ZP": @@ -170,27 +174,11 @@ class PlyParser: for node in module.all_nodes(SymbolName): check_symbol_definition(node.name, node.my_scope(), node.sourceref) # type: ignore - def process_all_expressions(self, module: Module) -> None: - # process/simplify all expressions (constant folding etc) - encountered_blocks = set() # type: Set[Block] + def simplify_some_assignments(self, module: Module) -> None: + # simplify some assignment statements, + # note taht most of the expression optimization (constant folding etc) is done in the optimizer. for node in module.all_nodes(): - if isinstance(node, Block): - parentname = (node.parent.name + ".") if node.parent else "" - blockname = parentname + node.name - if blockname in encountered_blocks: - raise ValueError("block names not unique:", blockname) - encountered_blocks.add(blockname) - elif isinstance(node, Expression): - try: - evaluated = process_expression(node, node.sourceref) - if evaluated is not node: - # replace the node with the newly evaluated result - node.parent.replace_node(node, evaluated) - except ParseError: - raise - except Exception as x: - self.handle_internal_error(x, "process_expressions of node {}".format(node)) - elif isinstance(node, IncrDecr) and node.howmuch not in (0, 1): + if isinstance(node, IncrDecr) and node.howmuch not in (0, 1): _, node.howmuch = coerce_constant_value(datatype_of(node.target, node.my_scope()), node.howmuch, node.sourceref) attr.validate(node) elif isinstance(node, VarDef): @@ -485,17 +473,6 @@ class PlyParser: print("\x1b[0m", file=out, end="", flush=True) raise exc # XXX temporary to see where the error occurred - def handle_internal_error(self, exc: Exception, msg: str="") -> None: - out = sys.stdout - if out.isatty(): - print("\x1b[1m", file=out) - print("\nERROR: internal parser error: ", exc, file=out) - if msg: - print(" Message:", msg, end="\n\n") - if out.isatty(): - print("\x1b[0m", file=out, end="", flush=True) - raise exc - class Zeropage: SCRATCH_B1 = 0x02 diff --git a/il65/datatypes.py b/il65/datatypes.py index eaf8e33f7..cc75757fc 100644 --- a/il65/datatypes.py +++ b/il65/datatypes.py @@ -5,11 +5,8 @@ Here are the data type definitions and -conversions. Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ -import math import enum -from typing import Tuple, Union from functools import total_ordering -from .plylex import print_warning, SourceRef @total_ordering diff --git a/il65/emit/incrdecr.py b/il65/emit/incrdecr.py index 619e53c36..59fd626ef 100644 --- a/il65/emit/incrdecr.py +++ b/il65/emit/incrdecr.py @@ -193,11 +193,13 @@ def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None: out("\vjsr c64flt.float_add_one") else: out("\vjsr c64flt.float_sub_one") - elif NOTYETIMPLEMENTED: # XXX for the float += otherfloat cases + else: + # XXX for the float += otherfloat cases + print("FLOAT INCR/DECR BY", stmt.howmuch) # XXX with preserving_registers({'A', 'X', 'Y'}, scope, out, loads_a_within=True): - out("\vlda #<" + stmt.value.name) + # XXX out("\vlda #<" + stmt.value.name) out("\vsta c64.SCRATCH_ZPWORD1") - out("\vlda #>" + stmt.value.name) + # XXX out("\vlda #>" + stmt.value.name) out("\vsta c64.SCRATCH_ZPWORD1+1") out("\vldx #<" + what_str) out("\vldy #>" + what_str) @@ -205,8 +207,6 @@ def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None: out("\vjsr c64flt.float_add_SW1_to_XY") else: out("\vjsr c64flt.float_sub_SW1_from_XY") - else: - raise CodeError("incr/decr missing float constant definition") else: raise CodeError("cannot in/decrement memory of type " + str(target.datatype), stmt.howmuch) diff --git a/il65/emit/variables.py b/il65/emit/variables.py index d42295f1d..02b4d799a 100644 --- a/il65/emit/variables.py +++ b/il65/emit/variables.py @@ -129,6 +129,7 @@ def generate_block_vars(out: Callable, block: Block, zeropage: bool=False) -> No _generate_string_var(out, vardef) else: raise CodeError("invalid const type", vardef) + # @todo float constants that are used in expressions out("; memory mapped variables") for vardef in vars_by_vartype.get(VarType.MEMORY, []): # create a definition for variables at a specific place in memory (memory-mapped) diff --git a/il65/main.py b/il65/main.py index 38b5ab597..d730ac526 100644 --- a/il65/main.py +++ b/il65/main.py @@ -83,7 +83,7 @@ def main() -> None: parsed_module = parser.parse_file(args.sourcefile) if parsed_module: if args.nooptimize: - print_bold("not optimizing the parse tree!") + print_bold("Optimizations disabled!") else: print("\nOptimizing code.") optimize(parsed_module) diff --git a/il65/optimize.py b/il65/optimize.py index 2bb121955..bccd8472d 100644 --- a/il65/optimize.py +++ b/il65/optimize.py @@ -6,21 +6,34 @@ eliminates statements that have no effect, optimizes calculations etc. Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ -from typing import List, no_type_check, Union -from .plyparse import AstNode, Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr,\ - datatype_of, coerce_constant_value, AssignmentTargets, LiteralValue, Scope, Register, SymbolName, \ - Dereference, TargetRegisters, VarDef +import sys +from typing import List, no_type_check, Union, Any +from .plyparse import * from .plylex import print_warning, print_bold, SourceRef -from .datatypes import DataType +from .datatypes import DataType, VarType class Optimizer: def __init__(self, mod: Module) -> None: self.num_warnings = 0 self.module = mod + self.optimizations_performed = False def optimize(self) -> None: self.num_warnings = 0 + self.optimizations_performed = True + # keep optimizing as long as there were changes made + while self.optimizations_performed: + self.optimizations_performed = False + self._optimize() + # remaining optimizations that have to be done just once: + self.remove_unused_subroutines() + self.remove_empty_blocks() + + def _optimize(self) -> None: + self.constant_folding() + # @todo expression optimization: reduce expression nesting + # @todo expression optimization: simplify logical expression when a term makes it always true or false self.create_aug_assignments() self.optimize_assignments() self.remove_superfluous_assignments() @@ -28,11 +41,32 @@ class Optimizer: self.optimize_multiassigns() # @todo optimize some simple multiplications into shifts (A*=8 -> A<<3) # @todo optimize addition with self into shift 1 (A+=A -> A<<=1) - self.remove_unused_subroutines() self.optimize_goto_compare_with_zero() self.join_incrdecrs() # @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to) - self.remove_empty_blocks() + + def handle_internal_error(self, exc: Exception, msg: str="") -> None: + out = sys.stdout + if out.isatty(): + print("\x1b[1m", file=out) + print("\nERROR: internal parser/optimizer error: ", exc, file=out) + if msg: + print(" Message:", msg, end="\n\n") + if out.isatty(): + print("\x1b[0m", file=out, end="", flush=True) + raise exc + + def constant_folding(self) -> None: + for expression in self.module.all_nodes(Expression): + try: + evaluated = process_expression(expression, expression.sourceref) # type: ignore + if evaluated is not expression: + # replace the node with the newly evaluated result + expression.parent.replace_node(expression, evaluated) + except ParseError: + raise + except Exception as x: + self.handle_internal_error(x, "process_expressions of node {}".format(expression)) def join_incrdecrs(self) -> None: for scope in self.module.all_nodes(Scope): @@ -88,6 +122,7 @@ class Optimizer: incrdecr = self._make_incrdecr(incrdecrs[0], target, total, "--") scope.replace_node(incrdecrs[0], incrdecr) if replaced: + self.optimizations_performed = True self.num_warnings += 1 print_warning("{}: merged a sequence of incr/decrs or augmented assignments".format(incrdecrs[0].sourceref)) incrdecrs.clear() @@ -122,6 +157,7 @@ class Optimizer: operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) + self.optimizations_performed = True continue if expr.operator not in ('+', '*', '|', '^'): # associative operators continue @@ -130,11 +166,13 @@ class Optimizer: operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) + self.optimizations_performed = True elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right): num_val = expr.left.const_num_val() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) + self.optimizations_performed = True def remove_superfluous_assignments(self) -> None: # remove consecutive assignment statements to the same target, only keep the last value (only if its a constant!) @@ -146,6 +184,7 @@ class Optimizer: if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left): if not node.left.has_memvalue(): scope.remove_node(prev_node) + self.optimizations_performed = True self.num_warnings += 1 print_warning("{}: removed superfluous assignment".format(prev_node.sourceref)) prev_node = node @@ -160,6 +199,7 @@ class Optimizer: if isinstance(assignment, Assignment): if all(lv == assignment.right for lv in assignment.left.nodes): assignment.my_scope().remove_node(assignment) + self.optimizations_performed = True self.num_warnings += 1 print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) if isinstance(assignment, AugAssignment): @@ -169,22 +209,26 @@ class Optimizer: self.num_warnings += 1 print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) assignment.my_scope().remove_node(assignment) + self.optimizations_performed = True elif assignment.operator == "*=": self.num_warnings += 1 print_warning("{}: statement replaced by = 0".format(assignment.sourceref)) new_assignment = self._make_new_assignment(assignment, 0) assignment.my_scope().replace_node(assignment, new_assignment) + self.optimizations_performed = True elif assignment.operator == "**=": self.num_warnings += 1 print_warning("{}: statement replaced by = 1".format(assignment.sourceref)) new_assignment = self._make_new_assignment(assignment, 1) assignment.my_scope().replace_node(assignment, new_assignment) + self.optimizations_performed = True if assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="): print("{}: shifting result is always zero".format(assignment.sourceref)) new_stmt = Assignment(sourceref=assignment.sourceref) new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref)) new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref)) assignment.my_scope().replace_node(assignment, new_stmt) + self.optimizations_performed = True if assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256: howmuch = assignment.right if howmuch.value not in (0, 1): @@ -194,10 +238,12 @@ class Optimizer: howmuch=howmuch.value, sourceref=assignment.sourceref) new_stmt.target = assignment.left assignment.my_scope().replace_node(assignment, new_stmt) + self.optimizations_performed = True if assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="): self.num_warnings += 1 print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) assignment.my_scope().remove_node(assignment) + self.optimizations_performed = True @no_type_check def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment: @@ -222,7 +268,7 @@ class Optimizer: @no_type_check def _make_incrdecr(self, old_stmt: AstNode, target: Union[TargetRegisters, Register, SymbolName, Dereference], - howmuch: Union[int, float], operator: str) -> AugAssignment: + howmuch: Union[int, float], operator: str) -> IncrDecr: a = IncrDecr(operator=operator, howmuch=howmuch, sourceref=old_stmt.sourceref) a.nodes.append(target) a.parent = old_stmt.parent @@ -245,6 +291,7 @@ class Optimizer: print("{}: joined with previous assignment".format(assignment.sourceref)) assignments[0].left.nodes.extend(assignment.left.nodes) scope.remove_node(assignment) + self.optimizations_performed = True rvalue = None assignments.clear() else: @@ -263,8 +310,9 @@ class Optimizer: lvalues = set(assignment.left.nodes) if len(lvalues) != len(assignment.left.nodes): print("{}: removed duplicate assignment targets".format(assignment.sourceref)) - # @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) - assignment.left.nodes = list(lvalues) + # @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) + assignment.left.nodes = list(lvalues) + self.optimizations_performed = True @no_type_check def remove_unused_subroutines(self) -> None: @@ -337,6 +385,155 @@ class Optimizer: node.my_scope().nodes.remove(node) +def process_expression(expr: Expression, sourceref: SourceRef) -> Any: + # process/simplify all expressions (constant folding etc) + if expr.must_be_constant: + return process_constant_expression(expr, sourceref) + else: + return process_dynamic_expression(expr, sourceref) + + +def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue: + # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. + if isinstance(expr, (int, float, str, bool)): + raise TypeError("expr node should not be a python primitive value", expr, sourceref) + elif expr is None or isinstance(expr, LiteralValue): + return expr + elif isinstance(expr, SymbolName): + value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) + if isinstance(value, VarDef): + if value.vartype == VarType.MEMORY: + raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref) + value = value.value + if isinstance(value, Expression): + raise ExpressionEvaluationError("circular reference?", expr.sourceref) + elif isinstance(value, LiteralValue): + return value + elif isinstance(value, (int, float, str, bool)): + raise TypeError("symbol value node should not be a python primitive value", expr) + else: + raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref) + elif isinstance(expr, AddressOf): + assert isinstance(expr.name, SymbolName) + value = check_symbol_definition(expr.name.name, expr.my_scope(), expr.sourceref) + if isinstance(value, VarDef): + if value.vartype == VarType.MEMORY: + if isinstance(value.value, LiteralValue): + return value.value + else: + raise ExpressionEvaluationError("constant literal value required", value.sourceref) + if value.vartype == VarType.CONST: + raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref) + raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant" + .format(value.__class__.__name__), expr.name.sourceref) + else: + raise ExpressionEvaluationError("constant address required, not {}" + .format(value.__class__.__name__), expr.name.sourceref) + elif isinstance(expr, SubCall): + if isinstance(expr.target, SymbolName): # 'function(1,2,3)' + funcname = expr.target.name + if funcname in math_functions or funcname in builtin_functions: + func_args = [] + for a in (process_constant_expression(callarg.value, sourceref) for callarg in expr.arguments.nodes): + if isinstance(a, LiteralValue): + func_args.append(a.value) + else: + func_args.append(a) + func = math_functions.get(funcname, builtin_functions.get(funcname)) + try: + return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore + except Exception as x: + raise ExpressionEvaluationError(str(x), expr.sourceref) + else: + raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref) + elif isinstance(expr.target, Dereference): # '[...](1,2,3)' + raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref) + elif type(expr.target) is int: # '64738()' + raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) + else: + raise NotImplementedError("weird call target", expr.target) + elif not isinstance(expr, Expression): + raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) + if expr.unary: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_constant_expression(expr.left, left_sourceref) + if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): + try: + if expr.operator == '-': + return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore + elif expr.operator == '~': + return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore + elif expr.operator in ("++", "--"): + raise ValueError("incr/decr should not be an expression") + raise ValueError("invalid unary operator", expr.operator) + except TypeError as x: + raise ParseError(str(x), expr.sourceref) from None + raise ValueError("invalid operand type for unary operator", expr.left, expr.operator) + else: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_constant_expression(expr.left, left_sourceref) + right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref + expr.right = process_constant_expression(expr.right, right_sourceref) + if isinstance(expr.left, LiteralValue): + if isinstance(expr.right, LiteralValue): + return expr.evaluate_primitive_constants(expr.right.sourceref) + else: + raise ExpressionEvaluationError("constant literal value required on right, not {}" + .format(expr.right.__class__.__name__), right_sourceref) + else: + raise ExpressionEvaluationError("constant literal value required on left, not {}" + .format(expr.left.__class__.__name__), left_sourceref) + + +def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any: + # constant-fold a dynamic expression + if isinstance(expr, (int, float, str, bool)): + raise TypeError("expr node should not be a python primitive value", expr, sourceref) + elif expr is None or isinstance(expr, LiteralValue): + return expr + elif isinstance(expr, SymbolName): + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr + elif isinstance(expr, AddressOf): + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr + elif isinstance(expr, SubCall): + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + if isinstance(expr.target, SymbolName): + check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref) + return expr + elif isinstance(expr, Register): + return expr + elif isinstance(expr, Dereference): + if isinstance(expr.operand, SymbolName): + check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref) + return expr + elif not isinstance(expr, Expression): + raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) + if expr.unary: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_dynamic_expression(expr.left, left_sourceref) + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr + else: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_dynamic_expression(expr.left, left_sourceref) + right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref + expr.right = process_dynamic_expression(expr.right, right_sourceref) + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr + + def optimize(mod: Module) -> None: opt = Optimizer(mod) opt.optimize() diff --git a/il65/plyparse.py b/il65/plyparse.py index 469b55903..e5419a07b 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -10,7 +10,7 @@ import builtins import inspect import enum from collections import defaultdict -from typing import Union, Generator, Tuple, Sequence, List, Optional, Dict, Any, no_type_check +from typing import Union, Generator, Tuple, List, Optional, Dict, Any, no_type_check import attr from ply.yacc import yacc from .plylex import SourceRef, tokens, lexer, find_tok_column, print_warning @@ -809,155 +809,6 @@ def coerce_constant_value(datatype: DataType, value: AstNode, return False, value -def process_expression(expr: Expression, sourceref: SourceRef) -> Any: - # process/simplify all expressions (constant folding etc) - if expr.must_be_constant: - return process_constant_expression(expr, sourceref) - else: - return process_dynamic_expression(expr, sourceref) - - -def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue: - # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. - if isinstance(expr, (int, float, str, bool)): - raise TypeError("expr node should not be a python primitive value", expr, sourceref) - elif expr is None or isinstance(expr, LiteralValue): - return expr - elif isinstance(expr, SymbolName): - value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) - if isinstance(value, VarDef): - if value.vartype == VarType.MEMORY: - raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref) - value = value.value - if isinstance(value, Expression): - raise ExpressionEvaluationError("circular reference?", expr.sourceref) - elif isinstance(value, LiteralValue): - return value - elif isinstance(value, (int, float, str, bool)): - raise TypeError("symbol value node should not be a python primitive value", expr) - else: - raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref) - elif isinstance(expr, AddressOf): - assert isinstance(expr.name, SymbolName) - value = check_symbol_definition(expr.name.name, expr.my_scope(), expr.sourceref) - if isinstance(value, VarDef): - if value.vartype == VarType.MEMORY: - if isinstance(value.value, LiteralValue): - return value.value - else: - raise ExpressionEvaluationError("constant literal value required", value.sourceref) - if value.vartype == VarType.CONST: - raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref) - raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant" - .format(value.__class__.__name__), expr.name.sourceref) - else: - raise ExpressionEvaluationError("constant address required, not {}" - .format(value.__class__.__name__), expr.name.sourceref) - elif isinstance(expr, SubCall): - if isinstance(expr.target, SymbolName): # 'function(1,2,3)' - funcname = expr.target.name - if funcname in math_functions or funcname in builtin_functions: - func_args = [] - for a in (process_constant_expression(callarg.value, sourceref) for callarg in expr.arguments.nodes): - if isinstance(a, LiteralValue): - func_args.append(a.value) - else: - func_args.append(a) - func = math_functions.get(funcname, builtin_functions.get(funcname)) - try: - return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore - except Exception as x: - raise ExpressionEvaluationError(str(x), expr.sourceref) - else: - raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref) - elif isinstance(expr.target, Dereference): # '[...](1,2,3)' - raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref) - elif type(expr.target) is int: # '64738()' - raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) - else: - raise NotImplementedError("weird call target", expr.target) - elif not isinstance(expr, Expression): - raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) - if expr.unary: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_constant_expression(expr.left, left_sourceref) - if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): - try: - if expr.operator == '-': - return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore - elif expr.operator == '~': - return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore - elif expr.operator in ("++", "--"): - raise ValueError("incr/decr should not be an expression") - raise ValueError("invalid unary operator", expr.operator) - except TypeError as x: - raise ParseError(str(x), expr.sourceref) from None - raise ValueError("invalid operand type for unary operator", expr.left, expr.operator) - else: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_constant_expression(expr.left, left_sourceref) - right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref - expr.right = process_constant_expression(expr.right, right_sourceref) - if isinstance(expr.left, LiteralValue): - if isinstance(expr.right, LiteralValue): - return expr.evaluate_primitive_constants(expr.right.sourceref) - else: - raise ExpressionEvaluationError("constant literal value required on right, not {}" - .format(expr.right.__class__.__name__), right_sourceref) - else: - raise ExpressionEvaluationError("constant literal value required on left, not {}" - .format(expr.left.__class__.__name__), left_sourceref) - - -def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any: - # constant-fold a dynamic expression - if isinstance(expr, (int, float, str, bool)): - raise TypeError("expr node should not be a python primitive value", expr, sourceref) - elif expr is None or isinstance(expr, LiteralValue): - return expr - elif isinstance(expr, SymbolName): - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr - elif isinstance(expr, AddressOf): - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr - elif isinstance(expr, SubCall): - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - if isinstance(expr.target, SymbolName): - check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref) - return expr - elif isinstance(expr, Register): - return expr - elif isinstance(expr, Dereference): - if isinstance(expr.operand, SymbolName): - check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref) - return expr - elif not isinstance(expr, Expression): - raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) - if expr.unary: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_dynamic_expression(expr.left, left_sourceref) - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr - else: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_dynamic_expression(expr.left, left_sourceref) - right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref - expr.right = process_dynamic_expression(expr.right, right_sourceref) - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr - - def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any: try: return scope.lookup(name) diff --git a/reference.md b/reference.md index 994b70e6a..43d2dbed6 100644 --- a/reference.md +++ b/reference.md @@ -25,7 +25,7 @@ which aims to provide many conveniences over raw assembly code (even when using - breakpoints, that let the Vice emulator drop into the monitor if execution hits them - source code labels automatically loaded in Vice emulator so it can show them in disassembly - conditional gotos -- some code optimizations (such as not repeatedly loading the same value in a register) +- various code optimizations (code structure, logical and numerical expressions, ...) - @todo: loops - @todo: memory block operations