much improved constant folding by actually evaluating const expressions

This commit is contained in:
Irmen de Jong 2018-02-07 02:10:52 +01:00
parent 938c541cc2
commit e0628c7814
5 changed files with 178 additions and 37 deletions

View File

@ -32,14 +32,14 @@ class Optimizer:
def _optimize(self) -> None:
self.constant_folding()
# @todo expression optimization: reduce expression nesting
# @todo expression optimization: reduce expression nesting / flattening of parenthesis
# @todo expression optimization: simplify logical expression when a term makes it always true or false
# @todo expression optimization: optimize some simple multiplications into shifts (A*=8 -> A<<3)
self.create_aug_assignments()
self.optimize_assignments()
self.remove_superfluous_assignments()
self.combine_assignments_into_multi()
self.optimize_multiassigns()
# @todo optimize some simple multiplications into shifts (A*=8 -> A<<3)
# @todo optimize addition with self into shift 1 (A+=A -> A<<=1)
self.optimize_goto_compare_with_zero()
self.join_incrdecrs()

View File

@ -63,6 +63,7 @@ tokens = (
"SHIFTRIGHT",
"LOGICAND",
"LOGICOR",
"LOGICXOR",
"LOGICNOT",
"INTEGERDIVIDE",
"MODULO",
@ -123,6 +124,7 @@ reserved = {
"not": "LOGICNOT",
"and": "LOGICAND",
"or": "LOGICOR",
"xor": "LOGICXOR",
"mod": "MODULO",
"AX": "REGISTER",
"AY": "REGISTER",

View File

@ -32,8 +32,9 @@ class ZpOptions(enum.Enum):
math_functions = {name: func for name, func in vars(math).items()
if inspect.isbuiltin(func) and name != "pow" and not name.startswith("_")}
builtin_functions = {name: func for name, func in vars(builtins).items()
if inspect.isbuiltin(func) and not name.startswith("_")}
builtin_functions = {name: getattr(builtins, name)
for name in ['abs', 'bin', 'chr', 'divmod', 'hash', 'hex', 'len', 'oct', 'ord', 'pow', 'round']}
# @todo support more builtins 'all', 'any', 'max', 'min', 'sum'
class ParseError(Exception):
@ -469,7 +470,9 @@ class AddressOf(Expression):
name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str))
def is_compile_constant(self) -> bool:
return False
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY or symdef.zp_address is not None # type: ignore
def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name)
@ -595,10 +598,75 @@ class ExpressionWithOperator(Expression):
self.operator = "%" # change it back to the more common '%'
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("an expression is not a constant", self)
cv = [n.const_value() for n in self.nodes] # type: ignore
if self.unary:
if self.operator == "-":
return -cv[0]
elif self.operator == "+":
return cv[0]
elif self.operator == "~":
return ~cv[0]
elif self.operator == "not":
return not cv[0]
elif self.operator == "&":
raise TypeError("the address-of operator should have been parsed into an AddressOf node")
else:
raise ValueError("invalid unary operator: "+self.operator, self.sourceref)
else:
if self.operator == "-":
return cv[0] - cv[1]
elif self.operator == "+":
return cv[0] + cv[1]
elif self.operator == "*":
return cv[0] * cv[1]
elif self.operator == "/":
return cv[0] / cv[1]
elif self.operator == "**":
return cv[0] ** cv[1]
elif self.operator == "//":
return cv[0] // cv[1]
elif self.operator in ("%", "mod"):
return cv[0] % cv[1]
elif self.operator == "<<":
return cv[0] << cv[1]
elif self.operator == ">>":
return cv[0] >> cv[1]
elif self.operator == "|":
return cv[0] | cv[1]
elif self.operator == "&":
return cv[0] & cv[1]
elif self.operator == "^":
return cv[0] ^ cv[1]
elif self.operator == "==":
return cv[0] == cv[1]
elif self.operator == "!=":
return cv[0] != cv[1]
elif self.operator == "<":
return cv[0] < cv[1]
elif self.operator == ">":
return cv[0] > cv[1]
elif self.operator == "<=":
return cv[0] <= cv[1]
elif self.operator == ">=":
return cv[0] >= cv[1]
elif self.operator == "and":
return cv[0] and cv[1]
elif self.operator == "or":
return cv[0] or cv[1]
elif self.operator == "xor":
i1 = 1 if cv[0] else 0
i2 = 1 if cv[1] else 0
return bool(i1 ^ i2)
else:
raise ValueError("invalid operator: "+self.operator, self.sourceref)
@no_type_check
def is_compile_constant(self) -> bool:
return False
if len(self.nodes) == 1:
return self.nodes[0].is_compile_constant()
elif len(self.nodes) == 2:
return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant()
raise ValueError("should have 1 or 2 nodes")
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
# make sure the lvalue and rvalue are primitives, and the operator is allowed
@ -667,9 +735,19 @@ class SubCall(Expression):
return self.nodes[2] # type: ignore
def is_compile_constant(self) -> bool:
if isinstance(self.nodes[0], SymbolName):
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
if isinstance(symdef, BuiltinFunction):
return True
return False
@no_type_check
def const_value(self) -> Union[int, float, bool, str]:
if isinstance(self.nodes[0], SymbolName):
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
if isinstance(symdef, BuiltinFunction):
arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes]
return symdef.func(*arguments)
raise TypeError("subroutine call is not a constant value", self)
@ -1439,6 +1517,7 @@ def p_aug_assignment(p):
precedence = (
# following the python operator precedence rules mostly; https://docs.python.org/3/reference/expressions.html#operator-precedence
('left', 'LOGICOR'),
('left', 'LOGICXOR'),
('left', 'LOGICAND'),
('right', 'LOGICNOT'),
('left', "LT", "GT", "LE", "GE", "EQUALS", "NOTEQUALS"),
@ -1468,6 +1547,7 @@ def p_expression(p):
| expression SHIFTRIGHT expression
| expression LOGICOR expression
| expression LOGICAND expression
| expression LOGICXOR expression
| expression POWER expression
| expression INTEGERDIVIDE expression
| expression LT expression

View File

@ -2,7 +2,7 @@ import math
import pytest
from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef
from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \
VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\
AstNode, Expression, Assignment, VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\
BuiltinFunction, UndefinedSymbolError
from il65.datatypes import DataType, VarType
@ -11,6 +11,15 @@ def lexer_error(sourceref: SourceRef, fmtstring: str, *args: str) -> None:
print("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
def parse_source(src: str) -> AstNode:
lexer.lineno = 1
lexer.source_filename = "sourcefile"
tfilt = TokenFilter(lexer)
result = parser.parse(input=src, tokenfunc=tfilt.token)
connect_parents(result, None)
return result
lexer.error_function = lexer_error
@ -112,11 +121,7 @@ def test_tokenfilter():
def test_parser():
lexer.lineno = 1
lexer.source_filename = "sourcefile"
filter = TokenFilter(lexer)
result = parser.parse(input=test_source_1, tokenfunc=filter.token)
connect_parents(result, None)
result = parse_source(test_source_1)
assert isinstance(result, Module)
assert result.name == "sourcefile"
assert result.scope.name == "<sourcefile global scope>"
@ -169,11 +174,7 @@ test_source_2 = """
def test_parser_2():
lexer.lineno = 1
lexer.source_filename = "sourcefile"
filter = TokenFilter(lexer)
result = parser.parse(input=test_source_2, tokenfunc=filter.token)
connect_parents(result, None)
result = parse_source(test_source_2)
block = result.scope.nodes[0]
call = block.scope.nodes[0]
assert isinstance(call, SubCall)
@ -198,11 +199,7 @@ test_source_3 = """
def test_typespec():
lexer.lineno = 1
lexer.source_filename = "sourcefile"
filter = TokenFilter(lexer)
result = parser.parse(input=test_source_3, tokenfunc=filter.token)
connect_parents(result, None)
result = parse_source(test_source_3)
block = result.scope.nodes[0]
assignment1, assignment2, assignment3, assignment4 = block.scope.nodes
assert assignment1.right.value == 5
@ -252,11 +249,7 @@ test_source_4 = """
def test_char_string():
lexer.lineno = 1
lexer.source_filename = "sourcefile"
filter = TokenFilter(lexer)
result = parser.parse(input=test_source_4, tokenfunc=filter.token)
connect_parents(result, None)
result = parse_source(test_source_4)
block = result.scope.nodes[0]
var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes
assert var1.value.value == 64
@ -279,11 +272,7 @@ test_source_5 = """
def test_boolean_int():
lexer.lineno = 1
lexer.source_filename = "sourcefile"
filter = TokenFilter(lexer)
result = parser.parse(input=test_source_5, tokenfunc=filter.token)
connect_parents(result, None)
result = parse_source(test_source_5)
block = result.scope.nodes[0]
var1, var2, assgn1, assgn2, = block.scope.nodes
assert type(var1.value.value) is int and var1.value.value == 1
@ -355,9 +344,9 @@ def test_symbol_lookup():
math_func = scope_inner.lookup("sin")
assert isinstance(math_func, BuiltinFunction)
assert math_func.name == "sin" and math_func.func is math.sin
builtin_func = scope_inner.lookup("max")
builtin_func = scope_inner.lookup("abs")
assert isinstance(builtin_func, BuiltinFunction)
assert builtin_func.name == "max" and builtin_func.func is max
assert builtin_func.name == "abs" and builtin_func.func is abs
# test dotted names:
with pytest.raises(UndefinedSymbolError):
scope_inner.lookup("noscope.nosymbol.nothing")
@ -367,3 +356,74 @@ def test_symbol_lookup():
with pytest.raises(UndefinedSymbolError):
scope_inner.lookup("outer.var2")
assert scope_inner.lookup("outer.var1") is var1
def test_const_numeric_expressions():
src = """
~ {
A = 1+2+3+4+5
X = 1+2*5+2
Y = (1+2)*(5+2)
A = (((10+20)/2)+5)**3
X = -10-11-12
Y = 1.234 mod (0.9 / 1.2)
A = sin(1.234)
X = round(4.567)-2
Y = 1+abs(-100)
A = ~1
X = -1
A = 4 << (9-3)
X = 5000 >> 2
Y = 999//88
A = &sin ; type error
}
"""
result = parse_source(src)
if isinstance(result, Module):
result.scope.define_builtin_functions()
assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments]
assert e[0].const_value() == 15 # 1+2+3+4+5
assert e[1].const_value() == 13 # 1+2*5+2
assert e[2].const_value() == 21 # (1+2)*(5+2)
assert e[3].const_value() == 8000 # (((10+20)/2)+5)**3
assert e[4].const_value() == -33 # -10-11-12
assert e[5].const_value() == 0.484 # 1.234 mod (0.9 / 1.2)
assert math.isclose(e[6].const_value(), 0.9438182093746337) # sin(1.234)
assert e[7].const_value() == 3 # round(4.567)-2
assert e[8].const_value() == 101 # 1+abs(-100)
assert e[9].const_value() == -2 # ~1
assert e[10].const_value() == -1 # -1
assert e[11].const_value() == 256 # 4 << (9-3)
assert e[12].const_value() == 1250 # 5000 >> 2
assert e[13].const_value() == 11 # 999//88
with pytest.raises(TypeError):
e[14].const_value()
def test_const_logic_expressions():
src = """
~ {
A = true or false
X = true and false
Y = true xor false
A = false and false or true
X = (false and (false or true))
Y = not (false or true)
A = 1 < 2
X = 1 >= 2
Y = 1 == (2+3)
}
"""
result = parse_source(src)
assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments]
assert e[0].const_value() == True
assert e[1].const_value() == False
assert e[2].const_value() == True
assert e[3].const_value() == True
assert e[4].const_value() == False
assert e[5].const_value() == False
assert e[6].const_value() == True
assert e[7].const_value() == False
assert e[8].const_value() == False

View File

@ -98,8 +98,7 @@ def test_const_value():
e = ExpressionWithOperator(operator="-", sourceref=sref)
e.left = LiteralValue(value=42, sourceref=sref)
v.value = e
with pytest.raises(TypeError):
v.const_value()
assert v.const_value() == -42
s = SymbolName(name="unexisting", sourceref=sref)
s.parent = scope
v.value = s