From e0628c7814d1857ab0a74e603c7018d06b579fb9 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Wed, 7 Feb 2018 02:10:52 +0100 Subject: [PATCH] much improved constant folding by actually evaluating const expressions --- il65/optimize.py | 4 +- il65/plylex.py | 2 + il65/plyparse.py | 90 +++++++++++++++++++++++++++++++-- tests/test_parser.py | 116 ++++++++++++++++++++++++++++++++----------- tests/test_vardef.py | 3 +- 5 files changed, 178 insertions(+), 37 deletions(-) diff --git a/il65/optimize.py b/il65/optimize.py index 53ad9083a..945f179c0 100644 --- a/il65/optimize.py +++ b/il65/optimize.py @@ -32,14 +32,14 @@ class Optimizer: def _optimize(self) -> None: self.constant_folding() - # @todo expression optimization: reduce expression nesting + # @todo expression optimization: reduce expression nesting / flattening of parenthesis # @todo expression optimization: simplify logical expression when a term makes it always true or false + # @todo expression optimization: optimize some simple multiplications into shifts (A*=8 -> A<<3) self.create_aug_assignments() self.optimize_assignments() self.remove_superfluous_assignments() self.combine_assignments_into_multi() self.optimize_multiassigns() - # @todo optimize some simple multiplications into shifts (A*=8 -> A<<3) # @todo optimize addition with self into shift 1 (A+=A -> A<<=1) self.optimize_goto_compare_with_zero() self.join_incrdecrs() diff --git a/il65/plylex.py b/il65/plylex.py index dcba4fced..9dc4f8128 100644 --- a/il65/plylex.py +++ b/il65/plylex.py @@ -63,6 +63,7 @@ tokens = ( "SHIFTRIGHT", "LOGICAND", "LOGICOR", + "LOGICXOR", "LOGICNOT", "INTEGERDIVIDE", "MODULO", @@ -123,6 +124,7 @@ reserved = { "not": "LOGICNOT", "and": "LOGICAND", "or": "LOGICOR", + "xor": "LOGICXOR", "mod": "MODULO", "AX": "REGISTER", "AY": "REGISTER", diff --git a/il65/plyparse.py b/il65/plyparse.py index ddf78b535..5d1ba45bf 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -32,8 +32,9 @@ class ZpOptions(enum.Enum): math_functions = {name: func for name, func in vars(math).items() if inspect.isbuiltin(func) and name != "pow" and not name.startswith("_")} -builtin_functions = {name: func for name, func in vars(builtins).items() - if inspect.isbuiltin(func) and not name.startswith("_")} +builtin_functions = {name: getattr(builtins, name) + for name in ['abs', 'bin', 'chr', 'divmod', 'hash', 'hex', 'len', 'oct', 'ord', 'pow', 'round']} +# @todo support more builtins 'all', 'any', 'max', 'min', 'sum' class ParseError(Exception): @@ -469,7 +470,9 @@ class AddressOf(Expression): name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str)) def is_compile_constant(self) -> bool: - return False + # address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable + symdef = self.my_scope().lookup(self.name) + return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY or symdef.zp_address is not None # type: ignore def const_value(self) -> Union[int, float, bool, str]: symdef = self.my_scope().lookup(self.name) @@ -595,10 +598,75 @@ class ExpressionWithOperator(Expression): self.operator = "%" # change it back to the more common '%' def const_value(self) -> Union[int, float, bool, str]: - raise TypeError("an expression is not a constant", self) + cv = [n.const_value() for n in self.nodes] # type: ignore + if self.unary: + if self.operator == "-": + return -cv[0] + elif self.operator == "+": + return cv[0] + elif self.operator == "~": + return ~cv[0] + elif self.operator == "not": + return not cv[0] + elif self.operator == "&": + raise TypeError("the address-of operator should have been parsed into an AddressOf node") + else: + raise ValueError("invalid unary operator: "+self.operator, self.sourceref) + else: + if self.operator == "-": + return cv[0] - cv[1] + elif self.operator == "+": + return cv[0] + cv[1] + elif self.operator == "*": + return cv[0] * cv[1] + elif self.operator == "/": + return cv[0] / cv[1] + elif self.operator == "**": + return cv[0] ** cv[1] + elif self.operator == "//": + return cv[0] // cv[1] + elif self.operator in ("%", "mod"): + return cv[0] % cv[1] + elif self.operator == "<<": + return cv[0] << cv[1] + elif self.operator == ">>": + return cv[0] >> cv[1] + elif self.operator == "|": + return cv[0] | cv[1] + elif self.operator == "&": + return cv[0] & cv[1] + elif self.operator == "^": + return cv[0] ^ cv[1] + elif self.operator == "==": + return cv[0] == cv[1] + elif self.operator == "!=": + return cv[0] != cv[1] + elif self.operator == "<": + return cv[0] < cv[1] + elif self.operator == ">": + return cv[0] > cv[1] + elif self.operator == "<=": + return cv[0] <= cv[1] + elif self.operator == ">=": + return cv[0] >= cv[1] + elif self.operator == "and": + return cv[0] and cv[1] + elif self.operator == "or": + return cv[0] or cv[1] + elif self.operator == "xor": + i1 = 1 if cv[0] else 0 + i2 = 1 if cv[1] else 0 + return bool(i1 ^ i2) + else: + raise ValueError("invalid operator: "+self.operator, self.sourceref) + @no_type_check def is_compile_constant(self) -> bool: - return False + if len(self.nodes) == 1: + return self.nodes[0].is_compile_constant() + elif len(self.nodes) == 2: + return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant() + raise ValueError("should have 1 or 2 nodes") def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue: # make sure the lvalue and rvalue are primitives, and the operator is allowed @@ -667,9 +735,19 @@ class SubCall(Expression): return self.nodes[2] # type: ignore def is_compile_constant(self) -> bool: + if isinstance(self.nodes[0], SymbolName): + symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name) + if isinstance(symdef, BuiltinFunction): + return True return False + @no_type_check def const_value(self) -> Union[int, float, bool, str]: + if isinstance(self.nodes[0], SymbolName): + symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name) + if isinstance(symdef, BuiltinFunction): + arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes] + return symdef.func(*arguments) raise TypeError("subroutine call is not a constant value", self) @@ -1439,6 +1517,7 @@ def p_aug_assignment(p): precedence = ( # following the python operator precedence rules mostly; https://docs.python.org/3/reference/expressions.html#operator-precedence ('left', 'LOGICOR'), + ('left', 'LOGICXOR'), ('left', 'LOGICAND'), ('right', 'LOGICNOT'), ('left', "LT", "GT", "LE", "GE", "EQUALS", "NOTEQUALS"), @@ -1468,6 +1547,7 @@ def p_expression(p): | expression SHIFTRIGHT expression | expression LOGICOR expression | expression LOGICAND expression + | expression LOGICXOR expression | expression POWER expression | expression INTEGERDIVIDE expression | expression LT expression diff --git a/tests/test_parser.py b/tests/test_parser.py index 5fa2d3547..e3621337b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2,7 +2,7 @@ import math import pytest from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \ - VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\ + AstNode, Expression, Assignment, VarDef, Register, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference,\ BuiltinFunction, UndefinedSymbolError from il65.datatypes import DataType, VarType @@ -11,6 +11,15 @@ def lexer_error(sourceref: SourceRef, fmtstring: str, *args: str) -> None: print("ERROR: {}: {}".format(sourceref, fmtstring.format(*args))) +def parse_source(src: str) -> AstNode: + lexer.lineno = 1 + lexer.source_filename = "sourcefile" + tfilt = TokenFilter(lexer) + result = parser.parse(input=src, tokenfunc=tfilt.token) + connect_parents(result, None) + return result + + lexer.error_function = lexer_error @@ -112,11 +121,7 @@ def test_tokenfilter(): def test_parser(): - lexer.lineno = 1 - lexer.source_filename = "sourcefile" - filter = TokenFilter(lexer) - result = parser.parse(input=test_source_1, tokenfunc=filter.token) - connect_parents(result, None) + result = parse_source(test_source_1) assert isinstance(result, Module) assert result.name == "sourcefile" assert result.scope.name == "" @@ -169,11 +174,7 @@ test_source_2 = """ def test_parser_2(): - lexer.lineno = 1 - lexer.source_filename = "sourcefile" - filter = TokenFilter(lexer) - result = parser.parse(input=test_source_2, tokenfunc=filter.token) - connect_parents(result, None) + result = parse_source(test_source_2) block = result.scope.nodes[0] call = block.scope.nodes[0] assert isinstance(call, SubCall) @@ -198,11 +199,7 @@ test_source_3 = """ def test_typespec(): - lexer.lineno = 1 - lexer.source_filename = "sourcefile" - filter = TokenFilter(lexer) - result = parser.parse(input=test_source_3, tokenfunc=filter.token) - connect_parents(result, None) + result = parse_source(test_source_3) block = result.scope.nodes[0] assignment1, assignment2, assignment3, assignment4 = block.scope.nodes assert assignment1.right.value == 5 @@ -252,11 +249,7 @@ test_source_4 = """ def test_char_string(): - lexer.lineno = 1 - lexer.source_filename = "sourcefile" - filter = TokenFilter(lexer) - result = parser.parse(input=test_source_4, tokenfunc=filter.token) - connect_parents(result, None) + result = parse_source(test_source_4) block = result.scope.nodes[0] var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes assert var1.value.value == 64 @@ -279,11 +272,7 @@ test_source_5 = """ def test_boolean_int(): - lexer.lineno = 1 - lexer.source_filename = "sourcefile" - filter = TokenFilter(lexer) - result = parser.parse(input=test_source_5, tokenfunc=filter.token) - connect_parents(result, None) + result = parse_source(test_source_5) block = result.scope.nodes[0] var1, var2, assgn1, assgn2, = block.scope.nodes assert type(var1.value.value) is int and var1.value.value == 1 @@ -355,9 +344,9 @@ def test_symbol_lookup(): math_func = scope_inner.lookup("sin") assert isinstance(math_func, BuiltinFunction) assert math_func.name == "sin" and math_func.func is math.sin - builtin_func = scope_inner.lookup("max") + builtin_func = scope_inner.lookup("abs") assert isinstance(builtin_func, BuiltinFunction) - assert builtin_func.name == "max" and builtin_func.func is max + assert builtin_func.name == "abs" and builtin_func.func is abs # test dotted names: with pytest.raises(UndefinedSymbolError): scope_inner.lookup("noscope.nosymbol.nothing") @@ -367,3 +356,74 @@ def test_symbol_lookup(): with pytest.raises(UndefinedSymbolError): scope_inner.lookup("outer.var2") assert scope_inner.lookup("outer.var1") is var1 + + +def test_const_numeric_expressions(): + src = """ +~ { + A = 1+2+3+4+5 + X = 1+2*5+2 + Y = (1+2)*(5+2) + A = (((10+20)/2)+5)**3 + X = -10-11-12 + Y = 1.234 mod (0.9 / 1.2) + A = sin(1.234) + X = round(4.567)-2 + Y = 1+abs(-100) + A = ~1 + X = -1 + A = 4 << (9-3) + X = 5000 >> 2 + Y = 999//88 + A = &sin ; type error +} +""" + result = parse_source(src) + if isinstance(result, Module): + result.scope.define_builtin_functions() + assignments = list(result.all_nodes(Assignment)) + e = [a.nodes[1] for a in assignments] + assert e[0].const_value() == 15 # 1+2+3+4+5 + assert e[1].const_value() == 13 # 1+2*5+2 + assert e[2].const_value() == 21 # (1+2)*(5+2) + assert e[3].const_value() == 8000 # (((10+20)/2)+5)**3 + assert e[4].const_value() == -33 # -10-11-12 + assert e[5].const_value() == 0.484 # 1.234 mod (0.9 / 1.2) + assert math.isclose(e[6].const_value(), 0.9438182093746337) # sin(1.234) + assert e[7].const_value() == 3 # round(4.567)-2 + assert e[8].const_value() == 101 # 1+abs(-100) + assert e[9].const_value() == -2 # ~1 + assert e[10].const_value() == -1 # -1 + assert e[11].const_value() == 256 # 4 << (9-3) + assert e[12].const_value() == 1250 # 5000 >> 2 + assert e[13].const_value() == 11 # 999//88 + with pytest.raises(TypeError): + e[14].const_value() + + +def test_const_logic_expressions(): + src = """ +~ { + A = true or false + X = true and false + Y = true xor false + A = false and false or true + X = (false and (false or true)) + Y = not (false or true) + A = 1 < 2 + X = 1 >= 2 + Y = 1 == (2+3) +} +""" + result = parse_source(src) + assignments = list(result.all_nodes(Assignment)) + e = [a.nodes[1] for a in assignments] + assert e[0].const_value() == True + assert e[1].const_value() == False + assert e[2].const_value() == True + assert e[3].const_value() == True + assert e[4].const_value() == False + assert e[5].const_value() == False + assert e[6].const_value() == True + assert e[7].const_value() == False + assert e[8].const_value() == False diff --git a/tests/test_vardef.py b/tests/test_vardef.py index eaf171f2e..8682e370f 100644 --- a/tests/test_vardef.py +++ b/tests/test_vardef.py @@ -98,8 +98,7 @@ def test_const_value(): e = ExpressionWithOperator(operator="-", sourceref=sref) e.left = LiteralValue(value=42, sourceref=sref) v.value = e - with pytest.raises(TypeError): - v.const_value() + assert v.const_value() == -42 s = SymbolName(name="unexisting", sourceref=sref) s.parent = scope v.value = s