From de3bca076368d00ad0bc42434e4cb8c736bc3dcd Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Sun, 18 Feb 2018 18:19:51 +0100 Subject: [PATCH] const folding cleanups and explicit notion of assignment LHS --- il65/compile.py | 14 ++-- il65/constantfold.py | 60 +++++++-------- il65/emit/assignment.py | 2 +- il65/emit/calls.py | 15 ++-- il65/emit/generate.py | 11 ++- il65/optimize.py | 48 +++++++----- il65/plyparse.py | 116 +++++++++++----------------- tests/test_parser.py | 162 +++++++++++++++++++++++++++------------- tests/test_vardef.py | 9 --- todo.ill | 1 + 10 files changed, 234 insertions(+), 204 deletions(-) diff --git a/il65/compile.py b/il65/compile.py index 0b86b820b..060fd4541 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -93,21 +93,21 @@ class PlyParser: # note: not processing regular assignments, because they can contain multiple targets of different datatype. # this has to be dealt with anyway later, so we don't bother dealing with it here for just a special case. if isinstance(node, AugAssignment): - if node.right.is_compile_constant(): + if node.right.is_compiletime_const(): _, node.right = coerce_constant_value(datatype_of(node.left, node.my_scope()), node.right, node.right.sourceref) elif isinstance(node, Goto): - if node.condition is not None and node.condition.is_compile_constant(): + if node.condition is not None and node.condition.is_compiletime_const(): _, node.nodes[1] = coerce_constant_value(DataType.WORD, node.nodes[1], node.nodes[1].sourceref) # type: ignore elif isinstance(node, Return): - if node.value_A is not None and node.value_A.is_compile_constant(): + if node.value_A is not None and node.value_A.is_compiletime_const(): _, node.nodes[0] = coerce_constant_value(DataType.BYTE, node.nodes[0], node.nodes[0].sourceref) # type: ignore - if node.value_X is not None and node.value_X.is_compile_constant(): + if node.value_X is not None and node.value_X.is_compiletime_const(): _, node.nodes[1] = coerce_constant_value(DataType.BYTE, node.nodes[1], node.nodes[1].sourceref) # type: ignore - if node.value_Y is not None and node.value_Y.is_compile_constant(): + if node.value_Y is not None and node.value_Y.is_compiletime_const(): _, node.nodes[2] = coerce_constant_value(DataType.BYTE, node.nodes[2], node.nodes[2].sourceref) # type: ignore elif isinstance(node, VarDef): if node.value is not None: - if node.value.is_compile_constant(): + if node.value.is_compiletime_const(): _, node.value = coerce_constant_value(datatype_of(node, node.my_scope()), node.value, node.value.sourceref) except OverflowError as x: raise ParseError(str(x), node.sourceref) @@ -189,7 +189,7 @@ class PlyParser: if isinstance(node.right, LiteralValue) and node.right.value == 0: raise ParseError("division by zero", node.right.sourceref) elif isinstance(node, VarDef): - if node.value is not None and not node.value.is_compile_constant(): + if node.value is not None and not node.value.is_compiletime_const(): raise ParseError("variable initialization value should be a compile-time constant", node.value.sourceref) elif isinstance(node, Dereference): if isinstance(node.operand, Register) and node.operand.datatype == DataType.BYTE: diff --git a/il65/constantfold.py b/il65/constantfold.py index 887afc953..904fd9e7a 100644 --- a/il65/constantfold.py +++ b/il65/constantfold.py @@ -62,8 +62,12 @@ class ConstantFold: def _process_expression(self, expr: Expression) -> Expression: # process/simplify all expressions (constant folding etc) + if expr.is_lhs: + if isinstance(expr, (Register, SymbolName, Dereference)): + return expr + raise ParseError("invalid lhs expression type", expr.sourceref) result = None # type: Expression - if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant: + if expr.is_compiletime_const(): result = self._process_constant_expression(expr, expr.sourceref) else: result = self._process_dynamic_expression(expr, expr.sourceref) @@ -74,9 +78,11 @@ class ConstantFold: # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. if isinstance(expr, LiteralValue): return expr - if expr.is_compile_constant(): + try: return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore - elif isinstance(expr, SymbolName): + except NotCompiletimeConstantError: + pass + if isinstance(expr, SymbolName): value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) if isinstance(value, VarDef): if value.vartype == VarType.MEMORY: @@ -169,26 +175,24 @@ class ConstantFold: # constant-fold a dynamic expression if isinstance(expr, LiteralValue): return expr - if expr.is_compile_constant(): + try: return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore - elif isinstance(expr, SymbolName): - if expr.is_compile_constant(): - try: - return self._process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - pass - return expr + except NotCompiletimeConstantError: + pass + if isinstance(expr, SymbolName): + try: + return self._process_constant_expression(expr, sourceref) + except (ExpressionEvaluationError, NotCompiletimeConstantError): + return expr elif isinstance(expr, AddressOf): - if expr.is_compile_constant(): - try: - return self._process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - pass - return expr + try: + return self._process_constant_expression(expr, sourceref) + except (ExpressionEvaluationError, NotCompiletimeConstantError): + return expr elif isinstance(expr, SubCall): try: return self._process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: + except (ExpressionEvaluationError, NotCompiletimeConstantError): if isinstance(expr.target, SymbolName): check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref) return expr @@ -199,12 +203,10 @@ class ConstantFold: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_dynamic_expression(expr.left, left_sourceref) expr.left.parent = expr - if expr.is_compile_constant(): - try: - return self._process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - pass - return expr + try: + return self._process_constant_expression(expr, sourceref) + except (ExpressionEvaluationError, NotCompiletimeConstantError): + return expr else: left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref expr.left = self._process_dynamic_expression(expr.left, left_sourceref) @@ -212,11 +214,9 @@ class ConstantFold: right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref expr.right = self._process_dynamic_expression(expr.right, right_sourceref) expr.right.parent = expr - if expr.is_compile_constant(): - try: - return self._process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - pass - return expr + try: + return self._process_constant_expression(expr, sourceref) + except (ExpressionEvaluationError, NotCompiletimeConstantError): + return expr else: raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) diff --git a/il65/emit/assignment.py b/il65/emit/assignment.py index af1ddb1ad..56b75cf7e 100644 --- a/il65/emit/assignment.py +++ b/il65/emit/assignment.py @@ -16,7 +16,7 @@ def generate_assignment(ctx: Context) -> None: assert isinstance(ctx.stmt, Assignment) assert not isinstance(ctx.stmt.right, Assignment), "assignment should have been flattened" ctx.out("\v\t\t\t; " + ctx.stmt.lineref) - ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left.nodes, ctx.stmt.right)) + ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left, ctx.stmt.right)) # @todo assignment diff --git a/il65/emit/calls.py b/il65/emit/calls.py index b0ade54f7..9b28d8e31 100644 --- a/il65/emit/calls.py +++ b/il65/emit/calls.py @@ -15,9 +15,9 @@ def generate_goto(ctx: Context) -> None: ctx.out("\v\t\t\t; " + ctx.stmt.lineref) if stmt.condition: if stmt.if_stmt: - _gen_goto_special_if_cond(ctx, stmt) + _gen_goto_cond(ctx, stmt, "true") else: - _gen_goto_cond(ctx, stmt) + _gen_goto_cond(ctx, stmt, stmt.if_cond) else: if stmt.if_stmt: _gen_goto_special_if(ctx, stmt) @@ -134,12 +134,11 @@ def _gen_goto_unconditional(ctx: Context, stmt: Goto) -> None: raise CodeError("invalid goto target type", stmt) -def _gen_goto_special_if_cond(ctx: Context, stmt: Goto) -> None: - pass # @todo special if WITH conditional expression - - -def _gen_goto_cond(ctx: Context, stmt: Goto) -> None: - pass # @todo regular if WITH conditional expression +def _gen_goto_cond(ctx: Context, stmt: Goto, if_cond: str) -> None: + if isinstance(stmt.condition, LiteralValue): + pass # @todo if WITH conditional expression + else: + raise CodeError("no support for evaluating conditional expression yet", stmt) # @todo def generate_subcall(ctx: Context) -> None: diff --git a/il65/emit/generate.py b/il65/emit/generate.py index f4f8ff01f..31561ba33 100644 --- a/il65/emit/generate.py +++ b/il65/emit/generate.py @@ -10,7 +10,7 @@ import datetime from typing import TextIO, Callable, no_type_check from ..plylex import print_bold from ..plyparse import (Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, ZpOptions, - InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr, AssignmentTargets) + InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr) from . import CodeError, to_hex, to_mflpt5, Context from .variables import generate_block_init, generate_block_vars from .assignment import generate_assignment, generate_aug_assignment @@ -195,22 +195,25 @@ class AssemblyGenerator: if stmt.value_A: reg = Register(name="A", sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref) - assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(reg) assignment.nodes.append(stmt.value_A) + assignment.mark_lhs() ctx.stmt = assignment generate_assignment(ctx) if stmt.value_X: reg = Register(name="X", sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref) - assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(reg) assignment.nodes.append(stmt.value_X) + assignment.mark_lhs() ctx.stmt = assignment generate_assignment(ctx) if stmt.value_Y: reg = Register(name="Y", sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref) - assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) + assignment.nodes.append(reg) assignment.nodes.append(stmt.value_Y) + assignment.mark_lhs() ctx.stmt = assignment generate_assignment(ctx) ctx.out("\vrts") diff --git a/il65/optimize.py b/il65/optimize.py index 2f422eedf..1d3681be5 100644 --- a/il65/optimize.py +++ b/il65/optimize.py @@ -7,7 +7,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0 """ from typing import List, no_type_check, Union -from .datatypes import DataType +from .datatypes import DataType, VarType from .plyparse import * from .plylex import print_warning, print_bold from .constantfold import ConstantFold @@ -117,8 +117,13 @@ class Optimizer: return True if isinstance(node1, SymbolName) and isinstance(node2, SymbolName) and node1.name == node2.name: return True - if isinstance(node1, Dereference) and isinstance(node2, Dereference) and node1.operand == node2.operand: - return True + if isinstance(node1, Dereference) and isinstance(node2, Dereference): + if type(node1.operand) is not type(node2.operand): + return False + if isinstance(node1.operand, (SymbolName, LiteralValue, Register)): + return node1.operand == node2.operand + if not isinstance(node1, AstNode) or not isinstance(node2, AstNode): + raise TypeError("same_target called with invalid type(s)", node1, node2) return False @no_type_check @@ -132,7 +137,7 @@ class Optimizer: continue expr = assignment.right if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators - if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): + if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left): num_val = expr.right.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) @@ -141,13 +146,13 @@ class Optimizer: continue if expr.operator not in ('+', '*', '|', '^'): # associative operators continue - if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): + if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left): num_val = expr.right.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) self.optimizations_performed = True - elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right): + elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.right): num_val = expr.left.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) @@ -161,12 +166,16 @@ class Optimizer: prev_node = None # type: AstNode for node in list(scope.nodes): if isinstance(node, Assignment) and isinstance(prev_node, Assignment): - if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left): - if not node.left.has_memvalue(): - scope.remove_node(prev_node) - self.optimizations_performed = True - self.num_warnings += 1 - print_warning("{}: removed superfluous assignment".format(prev_node.sourceref)) + if isinstance(node.right, (LiteralValue, Register)) and self._same_target(node.left, prev_node.left): + if isinstance(node.left, SymbolName): + # only optimize if the symbol is not a memory mapped address (volatile memory!) + symdef = node.left.my_scope().lookup(node.left.name) + if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY: + continue + scope.remove_node(prev_node) + self.optimizations_performed = True + self.num_warnings += 1 + print_warning("{}: removed superfluous assignment".format(prev_node.sourceref)) prev_node = node @no_type_check @@ -178,17 +187,17 @@ class Optimizer: # @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?) for assignment in self.module.all_nodes(): if isinstance(assignment, Assignment): - if all(lv == assignment.right for lv in assignment.left.nodes): + if self._same_target(assignment.left, assignment.right): assignment.my_scope().remove_node(assignment) self.optimizations_performed = True self.num_warnings += 1 - print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) + print_warning("{}: removed statement that has no effect (left=right)".format(assignment.sourceref)) elif isinstance(assignment, AugAssignment): if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)): if assignment.right.value == 0: if assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="): self.num_warnings += 1 - print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) + print_warning("{}: removed statement that has no effect (aug.assign zero)".format(assignment.sourceref)) assignment.my_scope().remove_node(assignment) self.optimizations_performed = True elif assignment.operator == "*=": @@ -206,9 +215,10 @@ class Optimizer: elif assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="): print("{}: shifting result is always zero".format(assignment.sourceref)) new_stmt = Assignment(sourceref=assignment.sourceref) - new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref)) + new_stmt.nodes.append(assignment.left) new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref)) assignment.my_scope().replace_node(assignment, new_stmt) + assignment.mark_lhs() self.optimizations_performed = True elif assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256: howmuch = assignment.right @@ -223,7 +233,7 @@ class Optimizer: self.optimizations_performed = True elif assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="): self.num_warnings += 1 - print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) + print_warning("{}: removed statement that has no effect (aug.assign identity)".format(assignment.sourceref)) assignment.my_scope().remove_node(assignment) self.optimizations_performed = True @@ -231,12 +241,13 @@ class Optimizer: def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment: new_assignment = Assignment(sourceref=old_aug_assignment.sourceref) new_assignment.parent = old_aug_assignment.parent - left = AssignmentTargets(nodes=[old_aug_assignment.left], sourceref=old_aug_assignment.sourceref) + left = old_aug_assignment.left left.parent = new_assignment new_assignment.nodes.append(left) value = LiteralValue(value=constantvalue, sourceref=old_aug_assignment.sourceref) value.parent = new_assignment new_assignment.nodes.append(value) + new_assignment.mark_lhs() return new_assignment @no_type_check @@ -250,6 +261,7 @@ class Optimizer: a.nodes.append(lv) lv.parent = a a.parent = old_assign.parent + a.mark_lhs() return a @no_type_check diff --git a/il65/plyparse.py b/il65/plyparse.py index 7eabc62d9..ea80d90cb 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -22,8 +22,8 @@ __all__ = ["ProgramFormat", "ZpOptions", "math_functions", "builtin_functions", "UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression", "Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr", "ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment", - "InlineAssembly", "AssignmentTargets", "BuiltinFunction", "TokenFilter", "parser", "connect_parents", - "parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition"] + "InlineAssembly", "BuiltinFunction", "TokenFilter", "parser", "connect_parents", + "parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition", "NotCompiletimeConstantError"] class ProgramFormat(enum.Enum): @@ -55,6 +55,10 @@ class ParseError(Exception): return "{} {:s}".format(self.sourceref, self.args[0]) +class NotCompiletimeConstantError(TypeError): + pass + + class ExpressionEvaluationError(ParseError): pass @@ -348,7 +352,9 @@ class Label(AstNode): class Expression(AstNode): # just a common base class for the nodes that are an expression themselves: # ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference - def is_compile_constant(self) -> bool: + is_lhs = attr.ib(type=bool, init=False, default=False) # left hand side of incrdecr/assignment/augassign? + + def is_compiletime_const(self) -> bool: raise NotImplementedError("implement in subclass") def const_value(self) -> Union[int, float, bool, str]: @@ -382,11 +388,11 @@ class Register(Expression): return NotImplemented return self.name < other.name - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: return False def const_value(self) -> Union[int, float, bool, str]: - raise TypeError("register doesn't have a constant numeric value", self) + raise NotCompiletimeConstantError("register doesn't have a constant numeric value", self) @attr.s(cmp=False) @@ -464,7 +470,7 @@ class LiteralValue(Expression): def const_value(self) -> Union[int, float, bool, str]: return self.value - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: return True @@ -473,7 +479,7 @@ class AddressOf(Expression): # no subnodes. name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str)) - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: # address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable symdef = self.my_scope().lookup(self.name) return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \ @@ -486,16 +492,16 @@ class AddressOf(Expression): return symdef.zp_address if symdef.vartype == VarType.MEMORY: return symdef.value.const_value() - raise TypeError("can only take constant address of a memory mapped variable", self) - raise TypeError("should be a vardef to be able to take its address", self) + raise NotCompiletimeConstantError("can only take constant address of a memory mapped variable", self) + raise NotCompiletimeConstantError("should be a vardef to be able to take its address", self) -@attr.s(cmp=False, slots=True) +@attr.s(cmp=True, slots=True) class SymbolName(Expression): # no subnodes. name = attr.ib(type=str) - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: symdef = self.my_scope().lookup(self.name) return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST @@ -503,7 +509,7 @@ class SymbolName(Expression): symdef = self.my_scope().lookup(self.name) if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: return symdef.const_value() - raise TypeError("should be a const vardef to be able to take its constant numeric value", self) + raise NotCompiletimeConstantError("should be a const vardef to be able to take its constant numeric value", self) @attr.s(cmp=False) @@ -531,11 +537,11 @@ class Dereference(Expression): if self.nodes and not isinstance(self.nodes[0], (SymbolName, LiteralValue, Register)): raise TypeError("operand of dereference invalid type", self.nodes[0], self.sourceref) - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: return False def const_value(self) -> Union[int, float, bool, str]: - raise TypeError("dereference is not a constant numeric value") + raise NotCompiletimeConstantError("dereference is not a constant numeric value") @attr.s(cmp=False) @@ -556,6 +562,8 @@ class IncrDecr(AstNode): raise ParseError("cannot incr/decr that register", self.sourceref) assert isinstance(target, (Register, SymbolName, Dereference)) self.nodes.clear() + # the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement. + target.is_lhs = True self.nodes.append(target) def __attrs_post_init__(self): @@ -569,8 +577,6 @@ class IncrDecr(AstNode): class ExpressionWithOperator(Expression): # 2 nodes: left (Expression), right (not present if unary, Expression if not unary) operator = attr.ib(type=str) - # when evaluating the expression, does it have to be a compile-time constant value? - must_be_constant = attr.ib(type=bool, init=False, default=False) @property def unary(self) -> bool: @@ -612,7 +618,7 @@ class ExpressionWithOperator(Expression): elif self.operator == "not": return not cv[0] elif self.operator == "&": - raise TypeError("the address-of operator should have been parsed into an AddressOf node") + raise NotCompiletimeConstantError("the address-of operator should have been parsed into an AddressOf node") else: raise ValueError("invalid unary operator: "+self.operator, self.sourceref) else: @@ -664,11 +670,11 @@ class ExpressionWithOperator(Expression): raise ValueError("invalid operator: "+self.operator, self.sourceref) @no_type_check - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: if len(self.nodes) == 1: - return self.nodes[0].is_compile_constant() + return self.nodes[0].is_compiletime_const() elif len(self.nodes) == 2: - return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant() + return self.nodes[0].is_compiletime_const() and self.nodes[1].is_compiletime_const() raise ValueError("should have 1 or 2 nodes") def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue: @@ -742,7 +748,7 @@ class SubCall(Expression): def arguments(self) -> CallArguments: return self.nodes[2] # type: ignore - def is_compile_constant(self) -> bool: + def is_compiletime_const(self) -> bool: if isinstance(self.nodes[0], SymbolName): symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name) if isinstance(symdef, BuiltinFunction): @@ -756,7 +762,7 @@ class SubCall(Expression): if isinstance(symdef, BuiltinFunction): arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes] return symdef.func(*arguments) - raise TypeError("subroutine call is not a constant value", self) + raise NotCompiletimeConstantError("subroutine call is not a constant value", self) @attr.s(cmp=False, slots=True, repr=False) @@ -779,13 +785,10 @@ class VarDef(AstNode): self.nodes[0] = value else: self.nodes.append(value) - if isinstance(value, ExpressionWithOperator): - # an expression in a vardef should evaluate to a compile-time constant: - value.must_be_constant = True def const_value(self) -> Union[int, float, bool, str]: if self.vartype != VarType.CONST: - raise TypeError("not a constant value", self) + raise NotCompiletimeConstantError("not a constant value", self) if self.nodes and isinstance(self.nodes[0], Expression): return self.nodes[0].const_value() raise ValueError("no value", self) @@ -841,59 +844,14 @@ class Return(AstNode): return self.nodes[2] if len(self.nodes) >= 3 else None # type: ignore -@attr.s(cmp=False, slots=True, repr=False) -class AssignmentTargets(AstNode): - # a list of one or more assignment targets (Register, SymbolName, or Dereference). - nodes = attr.ib(type=list, init=True) # requires nodes in __init__ - - def has_memvalue(self) -> bool: - for t in self.nodes: - if isinstance(t, Dereference): - return True - if isinstance(t, SymbolName): - symdef = self.my_scope().lookup(t.name) - if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY: - return True - return False - - def same_targets(self, other: 'AssignmentTargets') -> bool: - if len(self.nodes) != len(other.nodes): - return False - # @todo be able to compare targets in different order as well (sort them) - for t1, t2 in zip(self.nodes, other.nodes): - if type(t1) is not type(t2): - return False - if isinstance(t1, Register): - if t1 != t2: # __eq__ is defined - return False - elif isinstance(t1, SymbolName): - if t1.name != t2.name: - return False - elif isinstance(t1, Dereference): - if t1.size != t2.size or t1.datatype != t2.datatype: - return False - op1, op2 = t1.operand, t2.operand - if type(op1) is not type(op2): - return False - if isinstance(op1, SymbolName): - if op1.name != op2.name: - return False - else: - if op1 != op2: - return False - else: - return False - return True - - @attr.s(cmp=False, slots=True, repr=False) class Assignment(AstNode): # can be single- or multi-assignment - # has two subnodes: left (=AssignmentTargets) and right (=Expression, + # has two subnodes: left (=Register/SymbolName/Dereference ) and right (=Expression, # or another Assignment but those will be converted into multi assign) @property - def left(self) -> AssignmentTargets: + def left(self) -> Union[Register, SymbolName, Dereference]: return self.nodes[0] # type: ignore @property @@ -905,6 +863,10 @@ class Assignment(AstNode): assert isinstance(rvalue, Expression) self.nodes[1] = rvalue + def mark_lhs(self): + # the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement. + self.nodes[0].is_lhs = True + @attr.s(cmp=False, slots=True, repr=False) class AugAssignment(AstNode): @@ -924,6 +886,10 @@ class AugAssignment(AstNode): assert isinstance(rvalue, Expression) self.nodes[1] = rvalue + def mark_lhs(self): + # the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement. + self.nodes[0].is_lhs = True + def datatype_of(targetnode: AstNode, scope: Scope) -> DataType: # tries to determine the DataType of an assignment target node @@ -1512,8 +1478,9 @@ def p_assignment(p): | assignment_target IS assignment """ p[0] = Assignment(sourceref=_token_sref(p, 2)) - p[0].nodes.append(AssignmentTargets(nodes=[p[1]], sourceref=p[0].sourceref)) + p[0].nodes.append(p[1]) p[0].nodes.append(p[3]) + p[0].mark_lhs() def p_aug_assignment(p): @@ -1523,6 +1490,7 @@ def p_aug_assignment(p): p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2)) p[0].nodes.append(p[1]) p[0].nodes.append(p[3]) + p[0].mark_lhs() precedence = ( diff --git a/tests/test_parser.py b/tests/test_parser.py index 0f5e2bf25..34c2f69cb 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -164,16 +164,14 @@ def test_block_nodes(): assert len(sub2.scope.nodes) > 0 -test_source_2 = """ -~ { - 999(1,2) - [zz]() -} -""" - - def test_parser_2(): - result = parse_source(test_source_2) + src = """ + ~ { + 999(1,2) + [zz]() + } + """ + result = parse_source(src) block = result.scope.nodes[0] call = block.scope.nodes[0] assert isinstance(call, SubCall) @@ -187,18 +185,16 @@ def test_parser_2(): assert call.target.operand.name == "zz" -test_source_3 = """ -~ { - [$c000.word] = 5 - [$c000 .byte] = 5 - [AX .word] = 5 - [AX .float] = 5 -} -""" - - def test_typespec(): - result = parse_source(test_source_3) + src = """ + ~ { + [$c000.word] = 5 + [$c000 .byte] = 5 + [AX .word] = 5 + [AX .float] = 5 + } + """ + result = parse_source(src) block = result.scope.nodes[0] assignment1, assignment2, assignment3, assignment4 = block.scope.nodes assert assignment1.right.value == 5 @@ -209,10 +205,10 @@ def test_typespec(): assert len(assignment2.left.nodes) == 1 assert len(assignment3.left.nodes) == 1 assert len(assignment4.left.nodes) == 1 - t1 = assignment1.left.nodes[0] - t2 = assignment2.left.nodes[0] - t3 = assignment3.left.nodes[0] - t4 = assignment4.left.nodes[0] + t1 = assignment1.left + t2 = assignment2.left + t3 = assignment3.left + t4 = assignment4.left assert isinstance(t1, Dereference) assert isinstance(t2, Dereference) assert isinstance(t3, Dereference) @@ -235,20 +231,18 @@ def test_typespec(): assert t4.size is None -test_source_4 = """ -~ { - var x1 = '@' - var x2 = 'π' - var x3 = 'abc' - A = '@' - A = 'π' - A = 'abc' -} -""" - - def test_char_string(): - result = parse_source(test_source_4) + src = """ + ~ { + var x1 = '@' + var x2 = 'π' + var x3 = 'abc' + A = '@' + A = 'π' + A = 'abc' + } + """ + result = parse_source(src) block = result.scope.nodes[0] var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes assert var1.value.value == '@' @@ -260,18 +254,16 @@ def test_char_string(): # note: the actual one-charactor-to-bytevalue conversion is done at the very latest, when issuing an assignment statement -test_source_5 = """ -~ { - var x1 = true - var x2 = false - A = true - A = false -} -""" - - def test_boolean_int(): - result = parse_source(test_source_5) + src = """ + ~ { + var x1 = true + var x2 = false + A = true + A = false + } + """ + result = parse_source(src) block = result.scope.nodes[0] var1, var2, assgn1, assgn2, = block.scope.nodes assert type(var1.value.value) is int and var1.value.value == 1 @@ -381,7 +373,7 @@ def test_const_numeric_expressions(): result.scope.define_builtin_functions() assignments = list(result.all_nodes(Assignment)) e = [a.nodes[1] for a in assignments] - assert all(x.is_compile_constant() for x in e) + assert all(x.is_compiletime_const() for x in e) assert e[0].const_value() == 15 # 1+2+3+4+5 assert e[1].const_value() == 13 # 1+2*5+2 assert e[2].const_value() == 21 # (1+2)*(5+2) @@ -415,7 +407,7 @@ def test_const_logic_expressions(): result = parse_source(src) assignments = list(result.all_nodes(Assignment)) e = [a.nodes[1] for a in assignments] - assert all(x.is_compile_constant() for x in e) + assert all(x.is_compiletime_const() for x in e) assert e[0].const_value() == True assert e[1].const_value() == False assert e[2].const_value() == True @@ -441,12 +433,12 @@ def test_const_other_expressions(): result.scope.define_builtin_functions() assignments = list(result.all_nodes(Assignment)) e = [a.nodes[1] for a in assignments] - assert e[0].is_compile_constant() + assert e[0].is_compiletime_const() assert e[0].const_value() == 0xc123 - assert not e[1].is_compile_constant() + assert not e[1].is_compiletime_const() with pytest.raises(TypeError): e[1].const_value() - assert not e[2].is_compile_constant() + assert not e[2].is_compiletime_const() with pytest.raises(TypeError): e[2].const_value() @@ -495,3 +487,67 @@ def test_vdef_const_folds(): assert vd[2].datatype == DataType.BYTE assert isinstance(vd[2].value, LiteralValue) assert vd[2].value.value == 369 + + +def test_vdef_const_expressions(): + src = """ +~ { + var bvar = 99 + var .float fvar = sin(1.2-0.3) + var .float flt2 = -9.87e-6 + + bvar ++ + fvar ++ + flt2 ++ + bvar += 2+2 + fvar += 2+3 + flt2 += 2+4 + bvar = 2+5 + fvar = 2+6 + flt2 = 2+7 +} +""" + result = parse_source(src) + if isinstance(result, Module): + result.scope.define_builtin_functions() + cf = ConstantFold(result) + cf.fold_constants() + vd = list(result.all_nodes(VarDef)) + assert len(vd)==3 + assert vd[0].name == "bvar" + assert isinstance(vd[0].value, LiteralValue) + assert vd[0].value.value == 99 + assert vd[1].name == "fvar" + assert isinstance(vd[1].value, LiteralValue) + assert type(vd[1].value.value) is float + assert math.isclose(vd[1].value.value, math.sin(0.9)) + assert vd[2].name == "flt2" + assert isinstance(vd[2].value, LiteralValue) + assert math.isclose(-9.87e-6, vd[2].value.value) + # test incrdecr assignment target + nodes = list(result.all_nodes(IncrDecr)) + assert len(nodes) == 3 + assert isinstance(nodes[0].target, SymbolName) + assert nodes[0].target.name == "bvar" + assert isinstance(nodes[1].target, SymbolName) + assert nodes[1].target.name == "fvar" + assert isinstance(nodes[2].target, SymbolName) + assert nodes[2].target.name == "flt2" + # test augassign assignment target + nodes = list(result.all_nodes(AugAssignment)) + assert len(nodes) == 3 + assert isinstance(nodes[0].left, SymbolName) + assert nodes[0].left.name == "bvar" + assert isinstance(nodes[1].left, SymbolName) + assert nodes[1].left.name == "fvar" + assert isinstance(nodes[2].left, SymbolName) + assert nodes[2].left.name == "flt2" + # test assign assignment target + nodes = list(result.all_nodes(Assignment)) + assert len(nodes) == 3 + assert isinstance(nodes[0].left, SymbolName) + assert nodes[0].left.name == "bvar" + assert isinstance(nodes[1].left, SymbolName) + assert nodes[1].left.name == "fvar" + assert isinstance(nodes[2].left, SymbolName) + assert nodes[2].left.name == "flt2" diff --git a/tests/test_vardef.py b/tests/test_vardef.py index f71a7ed07..9e55b3244 100644 --- a/tests/test_vardef.py +++ b/tests/test_vardef.py @@ -3,13 +3,6 @@ from il65.datatypes import DataType from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError from il65.plylex import SourceRef -# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.). -# name = attr.ib(type=str) -# vartype = attr.ib() -# datatype = attr.ib() -# size = attr.ib(type=list, default=None) -# zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later - def test_creation(): sref = SourceRef("test", 1, 1) @@ -62,10 +55,8 @@ def test_set_value(): assert v.value.value == "hello" e = ExpressionWithOperator(operator="-", sourceref=sref) e.left = LiteralValue(value=42, sourceref=sref) - assert not e.must_be_constant v.value = e assert v.value is e - assert e.must_be_constant def test_const_value(): diff --git a/todo.ill b/todo.ill index 980da2b63..01d7f0bd5 100644 --- a/todo.ill +++ b/todo.ill @@ -1,4 +1,5 @@ ~ main { + var .float flt1 = 9.87e-21 var .float flt = -9.87e-21 const .word border = $0099 var counter = 1