mirror of
https://github.com/irmen/prog8.git
synced 2025-01-11 13:29:45 +00:00
much improved constant folding by actually evaluating const expressions
This commit is contained in:
parent
938c541cc2
commit
7b2af25a42
@ -32,17 +32,19 @@ class Optimizer:
|
||||
|
||||
def _optimize(self) -> None:
|
||||
self.constant_folding()
|
||||
# @todo expression optimization: reduce expression nesting
|
||||
# @todo expression optimization: reduce expression nesting / flattening of parenthesis
|
||||
# @todo expression optimization: simplify logical expression when a term makes it always true or false
|
||||
# @todo expression optimization: optimize some simple multiplications into shifts (A*=8 -> A<<3)
|
||||
self.create_aug_assignments()
|
||||
self.optimize_assignments()
|
||||
self.remove_superfluous_assignments()
|
||||
self.combine_assignments_into_multi()
|
||||
self.optimize_multiassigns()
|
||||
# @todo optimize some simple multiplications into shifts (A*=8 -> A<<3)
|
||||
# @todo optimize addition with self into shift 1 (A+=A -> A<<=1)
|
||||
self.optimize_goto_compare_with_zero()
|
||||
self.join_incrdecrs()
|
||||
# @todo remove gotos with conditions that are always false
|
||||
# @todo remove loops with conditions that are always empty/false
|
||||
# @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to)
|
||||
|
||||
def handle_internal_error(self, exc: Exception, msg: str="") -> None:
|
||||
|
@ -63,6 +63,7 @@ tokens = (
|
||||
"SHIFTRIGHT",
|
||||
"LOGICAND",
|
||||
"LOGICOR",
|
||||
"LOGICXOR",
|
||||
"LOGICNOT",
|
||||
"INTEGERDIVIDE",
|
||||
"MODULO",
|
||||
@ -123,6 +124,7 @@ reserved = {
|
||||
"not": "LOGICNOT",
|
||||
"and": "LOGICAND",
|
||||
"or": "LOGICOR",
|
||||
"xor": "LOGICXOR",
|
||||
"mod": "MODULO",
|
||||
"AX": "REGISTER",
|
||||
"AY": "REGISTER",
|
||||
|
@ -32,8 +32,9 @@ class ZpOptions(enum.Enum):
|
||||
|
||||
math_functions = {name: func for name, func in vars(math).items()
|
||||
if inspect.isbuiltin(func) and name != "pow" and not name.startswith("_")}
|
||||
builtin_functions = {name: func for name, func in vars(builtins).items()
|
||||
if inspect.isbuiltin(func) and not name.startswith("_")}
|
||||
builtin_functions = {name: getattr(builtins, name)
|
||||
for name in ['abs', 'bin', 'chr', 'divmod', 'hash', 'hex', 'len', 'oct', 'ord', 'pow', 'round']}
|
||||
# @todo support more builtins 'all', 'any', 'max', 'min', 'sum'
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
@ -469,7 +470,9 @@ class AddressOf(Expression):
|
||||
name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str))
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
return False
|
||||
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
|
||||
symdef = self.my_scope().lookup(self.name)
|
||||
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY or getattr(symdef, "zp_address", None) is not None # type: ignore
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
symdef = self.my_scope().lookup(self.name)
|
||||
@ -595,10 +598,75 @@ class ExpressionWithOperator(Expression):
|
||||
self.operator = "%" # change it back to the more common '%'
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
raise TypeError("an expression is not a constant", self)
|
||||
cv = [n.const_value() for n in self.nodes] # type: ignore
|
||||
if self.unary:
|
||||
if self.operator == "-":
|
||||
return -cv[0]
|
||||
elif self.operator == "+":
|
||||
return cv[0]
|
||||
elif self.operator == "~":
|
||||
return ~cv[0]
|
||||
elif self.operator == "not":
|
||||
return not cv[0]
|
||||
elif self.operator == "&":
|
||||
raise TypeError("the address-of operator should have been parsed into an AddressOf node")
|
||||
else:
|
||||
raise ValueError("invalid unary operator: "+self.operator, self.sourceref)
|
||||
else:
|
||||
if self.operator == "-":
|
||||
return cv[0] - cv[1]
|
||||
elif self.operator == "+":
|
||||
return cv[0] + cv[1]
|
||||
elif self.operator == "*":
|
||||
return cv[0] * cv[1]
|
||||
elif self.operator == "/":
|
||||
return cv[0] / cv[1]
|
||||
elif self.operator == "**":
|
||||
return cv[0] ** cv[1]
|
||||
elif self.operator == "//":
|
||||
return cv[0] // cv[1]
|
||||
elif self.operator in ("%", "mod"):
|
||||
return cv[0] % cv[1]
|
||||
elif self.operator == "<<":
|
||||
return cv[0] << cv[1]
|
||||
elif self.operator == ">>":
|
||||
return cv[0] >> cv[1]
|
||||
elif self.operator == "|":
|
||||
return cv[0] | cv[1]
|
||||
elif self.operator == "&":
|
||||
return cv[0] & cv[1]
|
||||
elif self.operator == "^":
|
||||
return cv[0] ^ cv[1]
|
||||
elif self.operator == "==":
|
||||
return cv[0] == cv[1]
|
||||
elif self.operator == "!=":
|
||||
return cv[0] != cv[1]
|
||||
elif self.operator == "<":
|
||||
return cv[0] < cv[1]
|
||||
elif self.operator == ">":
|
||||
return cv[0] > cv[1]
|
||||
elif self.operator == "<=":
|
||||
return cv[0] <= cv[1]
|
||||
elif self.operator == ">=":
|
||||
return cv[0] >= cv[1]
|
||||
elif self.operator == "and":
|
||||
return cv[0] and cv[1]
|
||||
elif self.operator == "or":
|
||||
return cv[0] or cv[1]
|
||||
elif self.operator == "xor":
|
||||
i1 = 1 if cv[0] else 0
|
||||
i2 = 1 if cv[1] else 0
|
||||
return bool(i1 ^ i2)
|
||||
else:
|
||||
raise ValueError("invalid operator: "+self.operator, self.sourceref)
|
||||
|
||||
@no_type_check
|
||||
def is_compile_constant(self) -> bool:
|
||||
return False
|
||||
if len(self.nodes) == 1:
|
||||
return self.nodes[0].is_compile_constant()
|
||||
elif len(self.nodes) == 2:
|
||||
return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant()
|
||||
raise ValueError("should have 1 or 2 nodes")
|
||||
|
||||
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
|
||||
# make sure the lvalue and rvalue are primitives, and the operator is allowed
|
||||
@ -667,9 +735,19 @@ class SubCall(Expression):
|
||||
return self.nodes[2] # type: ignore
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
if isinstance(self.nodes[0], SymbolName):
|
||||
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
|
||||
if isinstance(symdef, BuiltinFunction):
|
||||
return True
|
||||
return False
|
||||
|
||||
@no_type_check
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
if isinstance(self.nodes[0], SymbolName):
|
||||
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
|
||||
if isinstance(symdef, BuiltinFunction):
|
||||
arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes]
|
||||
return symdef.func(*arguments)
|
||||
raise TypeError("subroutine call is not a constant value", self)
|
||||
|
||||
|
||||
@ -1439,6 +1517,7 @@ def p_aug_assignment(p):
|
||||
precedence = (
|
||||
# following the python operator precedence rules mostly; https://docs.python.org/3/reference/expressions.html#operator-precedence
|
||||
('left', 'LOGICOR'),
|
||||
('left', 'LOGICXOR'),
|
||||
('left', 'LOGICAND'),
|
||||
('right', 'LOGICNOT'),
|
||||
('left', "LT", "GT", "LE", "GE", "EQUALS", "NOTEQUALS"),
|
||||
@ -1468,6 +1547,7 @@ def p_expression(p):
|
||||
| expression SHIFTRIGHT expression
|
||||
| expression LOGICOR expression
|
||||
| expression LOGICAND expression
|
||||
| expression LOGICXOR expression
|
||||
| expression POWER expression
|
||||
| expression INTEGERDIVIDE expression
|
||||
| expression LT expression
|
||||
|
@ -2,7 +2,7 @@ import math
|
||||
import pytest
|
||||
from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef
|
||||
from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \
|
||||
VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\
|
||||
AstNode, Expression, Assignment, VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\
|
||||
BuiltinFunction, UndefinedSymbolError
|
||||
from il65.datatypes import DataType, VarType
|
||||
|
||||
@ -11,6 +11,15 @@ def lexer_error(sourceref: SourceRef, fmtstring: str, *args: str) -> None:
|
||||
print("ERROR: {}: {}".format(sourceref, fmtstring.format(*args)))
|
||||
|
||||
|
||||
def parse_source(src: str) -> AstNode:
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
tfilt = TokenFilter(lexer)
|
||||
result = parser.parse(input=src, tokenfunc=tfilt.token)
|
||||
connect_parents(result, None)
|
||||
return result
|
||||
|
||||
|
||||
lexer.error_function = lexer_error
|
||||
|
||||
|
||||
@ -112,11 +121,7 @@ def test_tokenfilter():
|
||||
|
||||
|
||||
def test_parser():
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
filter = TokenFilter(lexer)
|
||||
result = parser.parse(input=test_source_1, tokenfunc=filter.token)
|
||||
connect_parents(result, None)
|
||||
result = parse_source(test_source_1)
|
||||
assert isinstance(result, Module)
|
||||
assert result.name == "sourcefile"
|
||||
assert result.scope.name == "<sourcefile global scope>"
|
||||
@ -169,11 +174,7 @@ test_source_2 = """
|
||||
|
||||
|
||||
def test_parser_2():
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
filter = TokenFilter(lexer)
|
||||
result = parser.parse(input=test_source_2, tokenfunc=filter.token)
|
||||
connect_parents(result, None)
|
||||
result = parse_source(test_source_2)
|
||||
block = result.scope.nodes[0]
|
||||
call = block.scope.nodes[0]
|
||||
assert isinstance(call, SubCall)
|
||||
@ -198,11 +199,7 @@ test_source_3 = """
|
||||
|
||||
|
||||
def test_typespec():
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
filter = TokenFilter(lexer)
|
||||
result = parser.parse(input=test_source_3, tokenfunc=filter.token)
|
||||
connect_parents(result, None)
|
||||
result = parse_source(test_source_3)
|
||||
block = result.scope.nodes[0]
|
||||
assignment1, assignment2, assignment3, assignment4 = block.scope.nodes
|
||||
assert assignment1.right.value == 5
|
||||
@ -252,11 +249,7 @@ test_source_4 = """
|
||||
|
||||
|
||||
def test_char_string():
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
filter = TokenFilter(lexer)
|
||||
result = parser.parse(input=test_source_4, tokenfunc=filter.token)
|
||||
connect_parents(result, None)
|
||||
result = parse_source(test_source_4)
|
||||
block = result.scope.nodes[0]
|
||||
var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes
|
||||
assert var1.value.value == 64
|
||||
@ -279,11 +272,7 @@ test_source_5 = """
|
||||
|
||||
|
||||
def test_boolean_int():
|
||||
lexer.lineno = 1
|
||||
lexer.source_filename = "sourcefile"
|
||||
filter = TokenFilter(lexer)
|
||||
result = parser.parse(input=test_source_5, tokenfunc=filter.token)
|
||||
connect_parents(result, None)
|
||||
result = parse_source(test_source_5)
|
||||
block = result.scope.nodes[0]
|
||||
var1, var2, assgn1, assgn2, = block.scope.nodes
|
||||
assert type(var1.value.value) is int and var1.value.value == 1
|
||||
@ -355,9 +344,9 @@ def test_symbol_lookup():
|
||||
math_func = scope_inner.lookup("sin")
|
||||
assert isinstance(math_func, BuiltinFunction)
|
||||
assert math_func.name == "sin" and math_func.func is math.sin
|
||||
builtin_func = scope_inner.lookup("max")
|
||||
builtin_func = scope_inner.lookup("abs")
|
||||
assert isinstance(builtin_func, BuiltinFunction)
|
||||
assert builtin_func.name == "max" and builtin_func.func is max
|
||||
assert builtin_func.name == "abs" and builtin_func.func is abs
|
||||
# test dotted names:
|
||||
with pytest.raises(UndefinedSymbolError):
|
||||
scope_inner.lookup("noscope.nosymbol.nothing")
|
||||
@ -367,3 +356,97 @@ def test_symbol_lookup():
|
||||
with pytest.raises(UndefinedSymbolError):
|
||||
scope_inner.lookup("outer.var2")
|
||||
assert scope_inner.lookup("outer.var1") is var1
|
||||
|
||||
|
||||
def test_const_numeric_expressions():
|
||||
src = """
|
||||
~ {
|
||||
A = 1+2+3+4+5
|
||||
X = 1+2*5+2
|
||||
Y = (1+2)*(5+2)
|
||||
A = (((10+20)/2)+5)**3
|
||||
X = -10-11-12
|
||||
Y = 1.234 mod (0.9 / 1.2)
|
||||
A = sin(1.234)
|
||||
X = round(4.567)-2
|
||||
Y = 1+abs(-100)
|
||||
A = ~1
|
||||
X = -1
|
||||
A = 4 << (9-3)
|
||||
X = 5000 >> 2
|
||||
Y = 999//88
|
||||
}
|
||||
"""
|
||||
result = parse_source(src)
|
||||
if isinstance(result, Module):
|
||||
result.scope.define_builtin_functions()
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert all(x.is_compile_constant() for x in e)
|
||||
assert e[0].const_value() == 15 # 1+2+3+4+5
|
||||
assert e[1].const_value() == 13 # 1+2*5+2
|
||||
assert e[2].const_value() == 21 # (1+2)*(5+2)
|
||||
assert e[3].const_value() == 8000 # (((10+20)/2)+5)**3
|
||||
assert e[4].const_value() == -33 # -10-11-12
|
||||
assert e[5].const_value() == 0.484 # 1.234 mod (0.9 / 1.2)
|
||||
assert math.isclose(e[6].const_value(), 0.9438182093746337) # sin(1.234)
|
||||
assert e[7].const_value() == 3 # round(4.567)-2
|
||||
assert e[8].const_value() == 101 # 1+abs(-100)
|
||||
assert e[9].const_value() == -2 # ~1
|
||||
assert e[10].const_value() == -1 # -1
|
||||
assert e[11].const_value() == 256 # 4 << (9-3)
|
||||
assert e[12].const_value() == 1250 # 5000 >> 2
|
||||
assert e[13].const_value() == 11 # 999//88
|
||||
|
||||
|
||||
def test_const_logic_expressions():
|
||||
src = """
|
||||
~ {
|
||||
A = true or false
|
||||
X = true and false
|
||||
Y = true xor false
|
||||
A = false and false or true
|
||||
X = (false and (false or true))
|
||||
Y = not (false or true)
|
||||
A = 1 < 2
|
||||
X = 1 >= 2
|
||||
Y = 1 == (2+3)
|
||||
}
|
||||
"""
|
||||
result = parse_source(src)
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert all(x.is_compile_constant() for x in e)
|
||||
assert e[0].const_value() == True
|
||||
assert e[1].const_value() == False
|
||||
assert e[2].const_value() == True
|
||||
assert e[3].const_value() == True
|
||||
assert e[4].const_value() == False
|
||||
assert e[5].const_value() == False
|
||||
assert e[6].const_value() == True
|
||||
assert e[7].const_value() == False
|
||||
assert e[8].const_value() == False
|
||||
|
||||
|
||||
def test_const_other_expressions():
|
||||
src = """
|
||||
~ {
|
||||
memory memvar = $c123
|
||||
A = &memvar ; constant
|
||||
X = &sin ; non-constant
|
||||
Y = [memvar] ; non-constant
|
||||
}
|
||||
"""
|
||||
result = parse_source(src)
|
||||
if isinstance(result, Module):
|
||||
result.scope.define_builtin_functions()
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert e[0].is_compile_constant()
|
||||
assert e[0].const_value() == 0xc123
|
||||
assert not e[1].is_compile_constant()
|
||||
with pytest.raises(TypeError):
|
||||
e[1].const_value()
|
||||
assert not e[2].is_compile_constant()
|
||||
with pytest.raises(TypeError):
|
||||
e[2].const_value()
|
||||
|
@ -98,8 +98,7 @@ def test_const_value():
|
||||
e = ExpressionWithOperator(operator="-", sourceref=sref)
|
||||
e.left = LiteralValue(value=42, sourceref=sref)
|
||||
v.value = e
|
||||
with pytest.raises(TypeError):
|
||||
v.const_value()
|
||||
assert v.const_value() == -42
|
||||
s = SymbolName(name="unexisting", sourceref=sref)
|
||||
s.parent = scope
|
||||
v.value = s
|
||||
|
Loading…
x
Reference in New Issue
Block a user