""" Programming Language for 6502/6510 microprocessors, codename 'Sick' This is the part of the compiler/optimizer that simplifies expressions by doing 'constant folding' - replacing expressions with constant, compile-time precomputed values. Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ import sys from .plylex import SourceRef from .datatypes import VarType from .plyparse import * def handle_internal_error(exc: Exception, msg: str = "") -> None: out = sys.stdout if out.isatty(): print("\x1b[1m", file=out) print("\nERROR: internal parser/optimizer error: ", exc, file=out) if msg: print(" Message:", msg, end="\n\n") if out.isatty(): print("\x1b[0m", file=out, end="", flush=True) raise exc class ConstantFold: def __init__(self, mod: Module) -> None: self.num_warnings = 0 self.module = mod self.optimizations_performed = False def fold_constants(self, once: bool=False) -> None: self.num_warnings = 0 if once: self._constant_folding() else: self.optimizations_performed = True # keep optimizing as long as there were changes made while self.optimizations_performed: self.optimizations_performed = False self._constant_folding() def _constant_folding(self) -> None: for expression in self.module.all_nodes(Expression): if expression.parent is None or expression.parent.parent is None: # stale expression node (was part of an expression that was constant-folded away) continue if isinstance(expression, LiteralValue): continue try: evaluated = self._process_expression(expression) # type: ignore if evaluated is not expression: # replace the node with the newly evaluated result parent = expression.parent parent.replace_node(expression, evaluated) self.optimizations_performed = True except ParseError: raise except Exception as x: handle_internal_error(x, "process_expressions of node {}".format(expression)) def _process_expression(self, expr: Expression) -> Expression: # process/simplify all expressions (constant folding etc) if expr.is_lhs: if isinstance(expr, (Register, SymbolName, Dereference)): return expr raise ParseError("invalid lhs expression type", expr.sourceref) result = None # type: Expression if expr.is_compiletime_const(): result = self._process_constant_expression(expr, expr.sourceref) else: result = self._process_dynamic_expression(expr, expr.sourceref) result.parent = expr.parent return result def _process_constant_expression(self, expr: Expression, sourceref: SourceRef) -> LiteralValue: # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. if isinstance(expr, LiteralValue): return expr try: return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore except NotCompiletimeConstantError: pass if isinstance(expr, SymbolName): value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) if isinstance(value, VarDef): if value.vartype == VarType.MEMORY: raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref) value = value.value if isinstance(value, ExpressionWithOperator): raise ExpressionEvaluationError("circular reference?", expr.sourceref) elif isinstance(value, LiteralValue): return value elif isinstance(value, (int, float, str, bool)): raise TypeError("symbol value node should not be a python primitive value", expr) else: raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref) elif isinstance(expr, AddressOf): assert isinstance(expr.name, str) value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) if isinstance(value, VarDef): if value.vartype == VarType.MEMORY: if isinstance(value.value, LiteralValue): return value.value else: raise ExpressionEvaluationError("constant literal value required", value.sourceref) if value.vartype == VarType.CONST: raise ExpressionEvaluationError("can't take the address of a constant", expr.sourceref) raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant" .format(value.__class__.__name__), expr.sourceref) else: raise ExpressionEvaluationError("constant address required, not {}" .format(value.__class__.__name__), expr.sourceref) elif isinstance(expr, SubCall): if isinstance(expr.target, SymbolName): # 'function(1,2,3)' funcname = expr.target.name if funcname in math_functions or funcname in builtin_functions: func_args = [] for a in (self._process_constant_expression(callarg.value, sourceref) for callarg in list(expr.arguments.nodes)): if isinstance(a, LiteralValue): func_args.append(a.value) else: func_args.append(a) func = math_functions.get(funcname, builtin_functions.get(funcname)) try: return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore except Exception as x: raise ExpressionEvaluationError(str(x), expr.sourceref) else: raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref) elif isinstance(expr.target, Dereference): # '[...](1,2,3)' raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref) elif isinstance(expr.target, LiteralValue) and type(expr.target.value) is int: # '64738()' raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) else: raise NotImplementedError("weird call target", expr.target) elif isinstance(expr, ExpressionWithOperator): if expr.unary: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_constant_expression(expr.left, left_sourceref) expr.left.parent = expr if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): try: if expr.operator == '-': return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore elif expr.operator == '~': return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore elif expr.operator in ("++", "--"): raise ValueError("incr/decr should not be an expression") raise ValueError("invalid unary operator", expr.operator) except TypeError as x: raise ParseError(str(x), expr.sourceref) from None raise ValueError("invalid operand type for unary operator", expr.left, expr.operator) else: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_constant_expression(expr.left, left_sourceref) expr.left.parent = expr right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref expr.right = self._process_constant_expression(expr.right, right_sourceref) expr.right.parent = expr if isinstance(expr.left, LiteralValue): if isinstance(expr.right, LiteralValue): return expr.evaluate_primitive_constants(expr.right.sourceref) else: raise ExpressionEvaluationError("constant literal value required on right, not {}" .format(expr.right.__class__.__name__), right_sourceref) else: raise ExpressionEvaluationError("constant literal value required on left, not {}" .format(expr.left.__class__.__name__), left_sourceref) else: raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) def _process_dynamic_expression(self, expr: Expression, sourceref: SourceRef) -> Expression: # constant-fold a dynamic expression if isinstance(expr, LiteralValue): return expr try: return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore except NotCompiletimeConstantError: pass if isinstance(expr, SymbolName): try: return self._process_constant_expression(expr, sourceref) except (ExpressionEvaluationError, NotCompiletimeConstantError): return expr elif isinstance(expr, AddressOf): try: return self._process_constant_expression(expr, sourceref) except (ExpressionEvaluationError, NotCompiletimeConstantError): return expr elif isinstance(expr, SubCall): try: return self._process_constant_expression(expr, sourceref) except (ExpressionEvaluationError, NotCompiletimeConstantError): if isinstance(expr.target, SymbolName): check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref) return expr elif isinstance(expr, (Register, Dereference)): return expr elif isinstance(expr, ExpressionWithOperator): if expr.unary: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_dynamic_expression(expr.left, left_sourceref) expr.left.parent = expr try: return self._process_constant_expression(expr, sourceref) except (ExpressionEvaluationError, NotCompiletimeConstantError): return expr else: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_dynamic_expression(expr.left, left_sourceref) expr.left.parent = expr right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref expr.right = self._process_dynamic_expression(expr.right, right_sourceref) expr.right.parent = expr try: return self._process_constant_expression(expr, sourceref) except (ExpressionEvaluationError, NotCompiletimeConstantError): return expr else: raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)