constant folding is now also done in regular compiler not only when optimizing (it's too valuable to skip)

This commit is contained in:
Irmen de Jong 2018-02-07 23:29:56 +01:00
parent 7b2af25a42
commit ca5f2f3607
4 changed files with 236 additions and 200 deletions

View File

@ -11,9 +11,10 @@ import sys
import linecache import linecache
from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check
import attr import attr
from .plyparse import *
from .plylex import SourceRef, print_bold
from .datatypes import DataType, VarType from .datatypes import DataType, VarType
from .plylex import SourceRef, print_bold
from .expressions import ExpressionOptimizer
from .plyparse import *
class CompileError(Exception): class CompileError(Exception):
@ -42,6 +43,8 @@ class PlyParser:
self.check_all_symbolnames(module) self.check_all_symbolnames(module)
self.determine_subroutine_usage(module) self.determine_subroutine_usage(module)
self.all_parents_connected(module) self.all_parents_connected(module)
eo = ExpressionOptimizer(module)
eo.optimize() # do some constant-folding
self.semantic_check(module) self.semantic_check(module)
self.coerce_values(module) self.coerce_values(module)
self.check_floats_enabled(module) self.check_floats_enabled(module)

220
il65/expressions.py Normal file
View File

@ -0,0 +1,220 @@
"""
Programming Language for 6502/6510 microprocessors, codename 'Sick'
This is the part of the compiler/optimizer that simplifies/evaluates expressions.
Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
"""
import sys
from .plylex import SourceRef
from .datatypes import VarType
from .plyparse import Module, Expression, LiteralValue, SymbolName, ParseError, VarDef, Dereference, Register,\
SubCall, AddressOf, AstNode, ExpressionWithOperator, ExpressionEvaluationError, \
math_functions, builtin_functions, check_symbol_definition
def handle_internal_error(exc: Exception, msg: str = "") -> None:
out = sys.stdout
if out.isatty():
print("\x1b[1m", file=out)
print("\nERROR: internal parser/optimizer error: ", exc, file=out)
if msg:
print(" Message:", msg, end="\n\n")
if out.isatty():
print("\x1b[0m", file=out, end="", flush=True)
raise exc
class ExpressionOptimizer:
def __init__(self, mod: Module) -> None:
self.num_warnings = 0
self.module = mod
self.optimizations_performed = False
def optimize(self, once: bool=False) -> None:
self.num_warnings = 0
if once:
self.constant_folding()
else:
self.optimizations_performed = True
# keep optimizing as long as there were changes made
while self.optimizations_performed:
self.optimizations_performed = False
self.constant_folding()
def constant_folding(self) -> None:
for expression in self.module.all_nodes(Expression):
if isinstance(expression, LiteralValue):
continue
try:
evaluated = self.process_expression(expression) # type: ignore
if evaluated is not expression:
# replace the node with the newly evaluated result
parent = expression.parent
parent.replace_node(expression, evaluated)
self.optimizations_performed = True
except ParseError:
raise
except Exception as x:
handle_internal_error(x, "process_expressions of node {}".format(expression))
def process_expression(self, expr: Expression) -> Expression:
# process/simplify all expressions (constant folding etc)
result = None # type: Expression
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
result = self._process_constant_expression(expr, expr.sourceref)
else:
result = self._process_dynamic_expression(expr, expr.sourceref)
result.parent = expr.parent
return result
def _process_constant_expression(self, expr: Expression, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, ExpressionWithOperator):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue):
return value
elif isinstance(value, (int, float, str, bool)):
raise TypeError("symbol value node should not be a python primitive value", expr)
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, AddressOf):
assert isinstance(expr.name, str)
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
if isinstance(value.value, LiteralValue):
return value.value
else:
raise ExpressionEvaluationError("constant literal value required", value.sourceref)
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, SubCall):
if isinstance(expr.target, SymbolName): # 'function(1,2,3)'
funcname = expr.target.name
if funcname in math_functions or funcname in builtin_functions:
func_args = []
for a in (self._process_constant_expression(callarg.value, sourceref) for callarg in list(expr.arguments.nodes)):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
elif isinstance(expr.target, Dereference): # '[...](1,2,3)'
raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref)
elif isinstance(expr.target, LiteralValue) and type(expr.target.value) is int: # '64738()'
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else:
raise NotImplementedError("weird call target", expr.target)
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_constant_expression(expr.left, left_sourceref)
expr.left.parent = expr
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_constant_expression(expr.left, left_sourceref)
expr.left.parent = expr
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = self._process_constant_expression(expr.right, right_sourceref)
expr.right.parent = expr
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
else:
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
def _process_dynamic_expression(self, expr: Expression, sourceref: SourceRef) -> Expression:
# constant-fold a dynamic expression
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
elif isinstance(expr, AddressOf):
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
elif isinstance(expr, SubCall):
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr
elif isinstance(expr, (Register, Dereference)):
return expr
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = self._process_dynamic_expression(expr.right, right_sourceref)
expr.right.parent = expr
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
else:
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)

View File

@ -6,11 +6,13 @@ eliminates statements that have no effect, optimizes calculations etc.
Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
import sys from typing import List, no_type_check, Union
from typing import List, no_type_check, Union, Any from .datatypes import DataType
from .plyparse import * from .plyparse import Module, Block, Scope, IncrDecr, AstNode, Register, TargetRegisters, Assignment, AugAssignment, \
from .plylex import print_warning, print_bold, SourceRef AssignmentTargets, SymbolName, VarDef, Dereference, LiteralValue, ExpressionWithOperator, Subroutine, \
from .datatypes import DataType, VarType Goto, Expression, Directive, coerce_constant_value, datatype_of
from .plylex import print_warning, print_bold
from .expressions import ExpressionOptimizer
class Optimizer: class Optimizer:
@ -18,6 +20,7 @@ class Optimizer:
self.num_warnings = 0 self.num_warnings = 0
self.module = mod self.module = mod
self.optimizations_performed = False self.optimizations_performed = False
self.simple_expression_optimizer = ExpressionOptimizer(self.module)
def optimize(self) -> None: def optimize(self) -> None:
self.num_warnings = 0 self.num_warnings = 0
@ -31,7 +34,7 @@ class Optimizer:
self.remove_empty_blocks() self.remove_empty_blocks()
def _optimize(self) -> None: def _optimize(self) -> None:
self.constant_folding() self.simple_expression_optimizer.optimize(True) # perform constant folding and simple expression optimization
# @todo expression optimization: reduce expression nesting / flattening of parenthesis # @todo expression optimization: reduce expression nesting / flattening of parenthesis
# @todo expression optimization: simplify logical expression when a term makes it always true or false # @todo expression optimization: simplify logical expression when a term makes it always true or false
# @todo expression optimization: optimize some simple multiplications into shifts (A*=8 -> A<<3) # @todo expression optimization: optimize some simple multiplications into shifts (A*=8 -> A<<3)
@ -47,33 +50,6 @@ class Optimizer:
# @todo remove loops with conditions that are always empty/false # @todo remove loops with conditions that are always empty/false
# @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to) # @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to)
def handle_internal_error(self, exc: Exception, msg: str="") -> None:
out = sys.stdout
if out.isatty():
print("\x1b[1m", file=out)
print("\nERROR: internal parser/optimizer error: ", exc, file=out)
if msg:
print(" Message:", msg, end="\n\n")
if out.isatty():
print("\x1b[0m", file=out, end="", flush=True)
raise exc
def constant_folding(self) -> None:
for expression in self.module.all_nodes(Expression):
if isinstance(expression, LiteralValue):
continue
try:
evaluated = process_expression(expression) # type: ignore
if evaluated is not expression:
# replace the node with the newly evaluated result
parent = expression.parent
parent.replace_node(expression, evaluated)
self.optimizations_performed = True
except ParseError:
raise
except Exception as x:
self.handle_internal_error(x, "process_expressions of node {}".format(expression))
def join_incrdecrs(self) -> None: def join_incrdecrs(self) -> None:
for scope in self.module.all_nodes(Scope): for scope in self.module.all_nodes(Scope):
incrdecrs = [] # type: List[IncrDecr] incrdecrs = [] # type: List[IncrDecr]
@ -399,170 +375,6 @@ class Optimizer:
node.my_scope().nodes.remove(node) node.my_scope().nodes.remove(node)
def process_expression(expr: Expression) -> Expression:
# process/simplify all expressions (constant folding etc)
result = None # type: Expression
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
result = _process_constant_expression(expr, expr.sourceref)
else:
result = _process_dynamic_expression(expr, expr.sourceref)
result.parent = expr.parent
return result
def _process_constant_expression(expr: Expression, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, ExpressionWithOperator):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue):
return value
elif isinstance(value, (int, float, str, bool)):
raise TypeError("symbol value node should not be a python primitive value", expr)
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, AddressOf):
assert isinstance(expr.name, str)
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
if isinstance(value.value, LiteralValue):
return value.value
else:
raise ExpressionEvaluationError("constant literal value required", value.sourceref)
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, SubCall):
if isinstance(expr.target, SymbolName): # 'function(1,2,3)'
funcname = expr.target.name
if funcname in math_functions or funcname in builtin_functions:
func_args = []
for a in (_process_constant_expression(callarg.value, sourceref) for callarg in list(expr.arguments.nodes)):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
elif isinstance(expr.target, Dereference): # '[...](1,2,3)'
raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref)
elif isinstance(expr.target, LiteralValue) and type(expr.target.value) is int: # '64738()'
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else:
raise NotImplementedError("weird call target", expr.target)
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = _process_constant_expression(expr.left, left_sourceref)
expr.left.parent = expr
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = _process_constant_expression(expr.left, left_sourceref)
expr.left.parent = expr
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = _process_constant_expression(expr.right, right_sourceref)
expr.right.parent = expr
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
else:
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
def _process_dynamic_expression(expr: Expression, sourceref: SourceRef) -> Expression:
# constant-fold a dynamic expression
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
if expr.is_compile_constant():
try:
return _process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
elif isinstance(expr, AddressOf):
if expr.is_compile_constant():
try:
return _process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
elif isinstance(expr, SubCall):
try:
return _process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr
elif isinstance(expr, (Register, Dereference)):
return expr
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = _process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr
if expr.is_compile_constant():
try:
return _process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = _process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = _process_dynamic_expression(expr.right, right_sourceref)
expr.right.parent = expr
if expr.is_compile_constant():
try:
return _process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
else:
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
def optimize(mod: Module) -> None: def optimize(mod: Module) -> None:
opt = Optimizer(mod) opt = Optimizer(mod)
opt.optimize() opt.optimize()

View File

@ -472,7 +472,8 @@ class AddressOf(Expression):
def is_compile_constant(self) -> bool: def is_compile_constant(self) -> bool:
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable # address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY or getattr(symdef, "zp_address", None) is not None # type: ignore return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \
or getattr(symdef, "zp_address", None) is not None # type: ignore
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)