const folding cleanups and explicit notion of assignment LHS

This commit is contained in:
Irmen de Jong 2018-02-18 18:19:51 +01:00
parent a171bb998d
commit de3bca0763
10 changed files with 234 additions and 204 deletions

View File

@ -93,21 +93,21 @@ class PlyParser:
# note: not processing regular assignments, because they can contain multiple targets of different datatype. # note: not processing regular assignments, because they can contain multiple targets of different datatype.
# this has to be dealt with anyway later, so we don't bother dealing with it here for just a special case. # this has to be dealt with anyway later, so we don't bother dealing with it here for just a special case.
if isinstance(node, AugAssignment): if isinstance(node, AugAssignment):
if node.right.is_compile_constant(): if node.right.is_compiletime_const():
_, node.right = coerce_constant_value(datatype_of(node.left, node.my_scope()), node.right, node.right.sourceref) _, node.right = coerce_constant_value(datatype_of(node.left, node.my_scope()), node.right, node.right.sourceref)
elif isinstance(node, Goto): elif isinstance(node, Goto):
if node.condition is not None and node.condition.is_compile_constant(): if node.condition is not None and node.condition.is_compiletime_const():
_, node.nodes[1] = coerce_constant_value(DataType.WORD, node.nodes[1], node.nodes[1].sourceref) # type: ignore _, node.nodes[1] = coerce_constant_value(DataType.WORD, node.nodes[1], node.nodes[1].sourceref) # type: ignore
elif isinstance(node, Return): elif isinstance(node, Return):
if node.value_A is not None and node.value_A.is_compile_constant(): if node.value_A is not None and node.value_A.is_compiletime_const():
_, node.nodes[0] = coerce_constant_value(DataType.BYTE, node.nodes[0], node.nodes[0].sourceref) # type: ignore _, node.nodes[0] = coerce_constant_value(DataType.BYTE, node.nodes[0], node.nodes[0].sourceref) # type: ignore
if node.value_X is not None and node.value_X.is_compile_constant(): if node.value_X is not None and node.value_X.is_compiletime_const():
_, node.nodes[1] = coerce_constant_value(DataType.BYTE, node.nodes[1], node.nodes[1].sourceref) # type: ignore _, node.nodes[1] = coerce_constant_value(DataType.BYTE, node.nodes[1], node.nodes[1].sourceref) # type: ignore
if node.value_Y is not None and node.value_Y.is_compile_constant(): if node.value_Y is not None and node.value_Y.is_compiletime_const():
_, node.nodes[2] = coerce_constant_value(DataType.BYTE, node.nodes[2], node.nodes[2].sourceref) # type: ignore _, node.nodes[2] = coerce_constant_value(DataType.BYTE, node.nodes[2], node.nodes[2].sourceref) # type: ignore
elif isinstance(node, VarDef): elif isinstance(node, VarDef):
if node.value is not None: if node.value is not None:
if node.value.is_compile_constant(): if node.value.is_compiletime_const():
_, node.value = coerce_constant_value(datatype_of(node, node.my_scope()), node.value, node.value.sourceref) _, node.value = coerce_constant_value(datatype_of(node, node.my_scope()), node.value, node.value.sourceref)
except OverflowError as x: except OverflowError as x:
raise ParseError(str(x), node.sourceref) raise ParseError(str(x), node.sourceref)
@ -189,7 +189,7 @@ class PlyParser:
if isinstance(node.right, LiteralValue) and node.right.value == 0: if isinstance(node.right, LiteralValue) and node.right.value == 0:
raise ParseError("division by zero", node.right.sourceref) raise ParseError("division by zero", node.right.sourceref)
elif isinstance(node, VarDef): elif isinstance(node, VarDef):
if node.value is not None and not node.value.is_compile_constant(): if node.value is not None and not node.value.is_compiletime_const():
raise ParseError("variable initialization value should be a compile-time constant", node.value.sourceref) raise ParseError("variable initialization value should be a compile-time constant", node.value.sourceref)
elif isinstance(node, Dereference): elif isinstance(node, Dereference):
if isinstance(node.operand, Register) and node.operand.datatype == DataType.BYTE: if isinstance(node.operand, Register) and node.operand.datatype == DataType.BYTE:

View File

@ -62,8 +62,12 @@ class ConstantFold:
def _process_expression(self, expr: Expression) -> Expression: def _process_expression(self, expr: Expression) -> Expression:
# process/simplify all expressions (constant folding etc) # process/simplify all expressions (constant folding etc)
if expr.is_lhs:
if isinstance(expr, (Register, SymbolName, Dereference)):
return expr
raise ParseError("invalid lhs expression type", expr.sourceref)
result = None # type: Expression result = None # type: Expression
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant: if expr.is_compiletime_const():
result = self._process_constant_expression(expr, expr.sourceref) result = self._process_constant_expression(expr, expr.sourceref)
else: else:
result = self._process_dynamic_expression(expr, expr.sourceref) result = self._process_dynamic_expression(expr, expr.sourceref)
@ -74,9 +78,11 @@ class ConstantFold:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, LiteralValue): if isinstance(expr, LiteralValue):
return expr return expr
if expr.is_compile_constant(): try:
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName): except NotCompiletimeConstantError:
pass
if isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef): if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY: if value.vartype == VarType.MEMORY:
@ -169,26 +175,24 @@ class ConstantFold:
# constant-fold a dynamic expression # constant-fold a dynamic expression
if isinstance(expr, LiteralValue): if isinstance(expr, LiteralValue):
return expr return expr
if expr.is_compile_constant(): try:
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName): except NotCompiletimeConstantError:
if expr.is_compile_constant(): pass
if isinstance(expr, SymbolName):
try: try:
return self._process_constant_expression(expr, sourceref) return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except (ExpressionEvaluationError, NotCompiletimeConstantError):
pass
return expr return expr
elif isinstance(expr, AddressOf): elif isinstance(expr, AddressOf):
if expr.is_compile_constant():
try: try:
return self._process_constant_expression(expr, sourceref) return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except (ExpressionEvaluationError, NotCompiletimeConstantError):
pass
return expr return expr
elif isinstance(expr, SubCall): elif isinstance(expr, SubCall):
try: try:
return self._process_constant_expression(expr, sourceref) return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except (ExpressionEvaluationError, NotCompiletimeConstantError):
if isinstance(expr.target, SymbolName): if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref) check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr return expr
@ -199,11 +203,9 @@ class ConstantFold:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_dynamic_expression(expr.left, left_sourceref) expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr expr.left.parent = expr
if expr.is_compile_constant():
try: try:
return self._process_constant_expression(expr, sourceref) return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except (ExpressionEvaluationError, NotCompiletimeConstantError):
pass
return expr return expr
else: else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
@ -212,11 +214,9 @@ class ConstantFold:
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = self._process_dynamic_expression(expr.right, right_sourceref) expr.right = self._process_dynamic_expression(expr.right, right_sourceref)
expr.right.parent = expr expr.right.parent = expr
if expr.is_compile_constant():
try: try:
return self._process_constant_expression(expr, sourceref) return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except (ExpressionEvaluationError, NotCompiletimeConstantError):
pass
return expr return expr
else: else:
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)

View File

@ -16,7 +16,7 @@ def generate_assignment(ctx: Context) -> None:
assert isinstance(ctx.stmt, Assignment) assert isinstance(ctx.stmt, Assignment)
assert not isinstance(ctx.stmt.right, Assignment), "assignment should have been flattened" assert not isinstance(ctx.stmt.right, Assignment), "assignment should have been flattened"
ctx.out("\v\t\t\t; " + ctx.stmt.lineref) ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left.nodes, ctx.stmt.right)) ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left, ctx.stmt.right))
# @todo assignment # @todo assignment

View File

@ -15,9 +15,9 @@ def generate_goto(ctx: Context) -> None:
ctx.out("\v\t\t\t; " + ctx.stmt.lineref) ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
if stmt.condition: if stmt.condition:
if stmt.if_stmt: if stmt.if_stmt:
_gen_goto_special_if_cond(ctx, stmt) _gen_goto_cond(ctx, stmt, "true")
else: else:
_gen_goto_cond(ctx, stmt) _gen_goto_cond(ctx, stmt, stmt.if_cond)
else: else:
if stmt.if_stmt: if stmt.if_stmt:
_gen_goto_special_if(ctx, stmt) _gen_goto_special_if(ctx, stmt)
@ -134,12 +134,11 @@ def _gen_goto_unconditional(ctx: Context, stmt: Goto) -> None:
raise CodeError("invalid goto target type", stmt) raise CodeError("invalid goto target type", stmt)
def _gen_goto_special_if_cond(ctx: Context, stmt: Goto) -> None: def _gen_goto_cond(ctx: Context, stmt: Goto, if_cond: str) -> None:
pass # @todo special if WITH conditional expression if isinstance(stmt.condition, LiteralValue):
pass # @todo if WITH conditional expression
else:
def _gen_goto_cond(ctx: Context, stmt: Goto) -> None: raise CodeError("no support for evaluating conditional expression yet", stmt) # @todo
pass # @todo regular if WITH conditional expression
def generate_subcall(ctx: Context) -> None: def generate_subcall(ctx: Context) -> None:

View File

@ -10,7 +10,7 @@ import datetime
from typing import TextIO, Callable, no_type_check from typing import TextIO, Callable, no_type_check
from ..plylex import print_bold from ..plylex import print_bold
from ..plyparse import (Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, ZpOptions, from ..plyparse import (Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, ZpOptions,
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr, AssignmentTargets) InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr)
from . import CodeError, to_hex, to_mflpt5, Context from . import CodeError, to_hex, to_mflpt5, Context
from .variables import generate_block_init, generate_block_vars from .variables import generate_block_init, generate_block_vars
from .assignment import generate_assignment, generate_aug_assignment from .assignment import generate_assignment, generate_aug_assignment
@ -195,22 +195,25 @@ class AssemblyGenerator:
if stmt.value_A: if stmt.value_A:
reg = Register(name="A", sourceref=stmt.sourceref) reg = Register(name="A", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_A) assignment.nodes.append(stmt.value_A)
assignment.mark_lhs()
ctx.stmt = assignment ctx.stmt = assignment
generate_assignment(ctx) generate_assignment(ctx)
if stmt.value_X: if stmt.value_X:
reg = Register(name="X", sourceref=stmt.sourceref) reg = Register(name="X", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_X) assignment.nodes.append(stmt.value_X)
assignment.mark_lhs()
ctx.stmt = assignment ctx.stmt = assignment
generate_assignment(ctx) generate_assignment(ctx)
if stmt.value_Y: if stmt.value_Y:
reg = Register(name="Y", sourceref=stmt.sourceref) reg = Register(name="Y", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref) assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref)) assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_Y) assignment.nodes.append(stmt.value_Y)
assignment.mark_lhs()
ctx.stmt = assignment ctx.stmt = assignment
generate_assignment(ctx) generate_assignment(ctx)
ctx.out("\vrts") ctx.out("\vrts")

View File

@ -7,7 +7,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
from typing import List, no_type_check, Union from typing import List, no_type_check, Union
from .datatypes import DataType from .datatypes import DataType, VarType
from .plyparse import * from .plyparse import *
from .plylex import print_warning, print_bold from .plylex import print_warning, print_bold
from .constantfold import ConstantFold from .constantfold import ConstantFold
@ -117,8 +117,13 @@ class Optimizer:
return True return True
if isinstance(node1, SymbolName) and isinstance(node2, SymbolName) and node1.name == node2.name: if isinstance(node1, SymbolName) and isinstance(node2, SymbolName) and node1.name == node2.name:
return True return True
if isinstance(node1, Dereference) and isinstance(node2, Dereference) and node1.operand == node2.operand: if isinstance(node1, Dereference) and isinstance(node2, Dereference):
return True if type(node1.operand) is not type(node2.operand):
return False
if isinstance(node1.operand, (SymbolName, LiteralValue, Register)):
return node1.operand == node2.operand
if not isinstance(node1, AstNode) or not isinstance(node2, AstNode):
raise TypeError("same_target called with invalid type(s)", node1, node2)
return False return False
@no_type_check @no_type_check
@ -132,7 +137,7 @@ class Optimizer:
continue continue
expr = assignment.right expr = assignment.right
if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
num_val = expr.right.const_value() num_val = expr.right.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
@ -141,13 +146,13 @@ class Optimizer:
continue continue
if expr.operator not in ('+', '*', '|', '^'): # associative operators if expr.operator not in ('+', '*', '|', '^'): # associative operators
continue continue
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
num_val = expr.right.const_value() num_val = expr.right.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign) assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True self.optimizations_performed = True
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right): elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.right):
num_val = expr.left.const_value() num_val = expr.left.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
@ -161,8 +166,12 @@ class Optimizer:
prev_node = None # type: AstNode prev_node = None # type: AstNode
for node in list(scope.nodes): for node in list(scope.nodes):
if isinstance(node, Assignment) and isinstance(prev_node, Assignment): if isinstance(node, Assignment) and isinstance(prev_node, Assignment):
if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left): if isinstance(node.right, (LiteralValue, Register)) and self._same_target(node.left, prev_node.left):
if not node.left.has_memvalue(): if isinstance(node.left, SymbolName):
# only optimize if the symbol is not a memory mapped address (volatile memory!)
symdef = node.left.my_scope().lookup(node.left.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
continue
scope.remove_node(prev_node) scope.remove_node(prev_node)
self.optimizations_performed = True self.optimizations_performed = True
self.num_warnings += 1 self.num_warnings += 1
@ -178,17 +187,17 @@ class Optimizer:
# @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?) # @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?)
for assignment in self.module.all_nodes(): for assignment in self.module.all_nodes():
if isinstance(assignment, Assignment): if isinstance(assignment, Assignment):
if all(lv == assignment.right for lv in assignment.left.nodes): if self._same_target(assignment.left, assignment.right):
assignment.my_scope().remove_node(assignment) assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True self.optimizations_performed = True
self.num_warnings += 1 self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) print_warning("{}: removed statement that has no effect (left=right)".format(assignment.sourceref))
elif isinstance(assignment, AugAssignment): elif isinstance(assignment, AugAssignment):
if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)): if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)):
if assignment.right.value == 0: if assignment.right.value == 0:
if assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="): if assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="):
self.num_warnings += 1 self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) print_warning("{}: removed statement that has no effect (aug.assign zero)".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment) assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True self.optimizations_performed = True
elif assignment.operator == "*=": elif assignment.operator == "*=":
@ -206,9 +215,10 @@ class Optimizer:
elif assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="): elif assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="):
print("{}: shifting result is always zero".format(assignment.sourceref)) print("{}: shifting result is always zero".format(assignment.sourceref))
new_stmt = Assignment(sourceref=assignment.sourceref) new_stmt = Assignment(sourceref=assignment.sourceref)
new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref)) new_stmt.nodes.append(assignment.left)
new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref)) new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref))
assignment.my_scope().replace_node(assignment, new_stmt) assignment.my_scope().replace_node(assignment, new_stmt)
assignment.mark_lhs()
self.optimizations_performed = True self.optimizations_performed = True
elif assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256: elif assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256:
howmuch = assignment.right howmuch = assignment.right
@ -223,7 +233,7 @@ class Optimizer:
self.optimizations_performed = True self.optimizations_performed = True
elif assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="): elif assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="):
self.num_warnings += 1 self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) print_warning("{}: removed statement that has no effect (aug.assign identity)".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment) assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True self.optimizations_performed = True
@ -231,12 +241,13 @@ class Optimizer:
def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment: def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment:
new_assignment = Assignment(sourceref=old_aug_assignment.sourceref) new_assignment = Assignment(sourceref=old_aug_assignment.sourceref)
new_assignment.parent = old_aug_assignment.parent new_assignment.parent = old_aug_assignment.parent
left = AssignmentTargets(nodes=[old_aug_assignment.left], sourceref=old_aug_assignment.sourceref) left = old_aug_assignment.left
left.parent = new_assignment left.parent = new_assignment
new_assignment.nodes.append(left) new_assignment.nodes.append(left)
value = LiteralValue(value=constantvalue, sourceref=old_aug_assignment.sourceref) value = LiteralValue(value=constantvalue, sourceref=old_aug_assignment.sourceref)
value.parent = new_assignment value.parent = new_assignment
new_assignment.nodes.append(value) new_assignment.nodes.append(value)
new_assignment.mark_lhs()
return new_assignment return new_assignment
@no_type_check @no_type_check
@ -250,6 +261,7 @@ class Optimizer:
a.nodes.append(lv) a.nodes.append(lv)
lv.parent = a lv.parent = a
a.parent = old_assign.parent a.parent = old_assign.parent
a.mark_lhs()
return a return a
@no_type_check @no_type_check

View File

@ -22,8 +22,8 @@ __all__ = ["ProgramFormat", "ZpOptions", "math_functions", "builtin_functions",
"UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression", "UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression",
"Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr", "Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr",
"ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment", "ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment",
"InlineAssembly", "AssignmentTargets", "BuiltinFunction", "TokenFilter", "parser", "connect_parents", "InlineAssembly", "BuiltinFunction", "TokenFilter", "parser", "connect_parents",
"parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition"] "parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition", "NotCompiletimeConstantError"]
class ProgramFormat(enum.Enum): class ProgramFormat(enum.Enum):
@ -55,6 +55,10 @@ class ParseError(Exception):
return "{} {:s}".format(self.sourceref, self.args[0]) return "{} {:s}".format(self.sourceref, self.args[0])
class NotCompiletimeConstantError(TypeError):
pass
class ExpressionEvaluationError(ParseError): class ExpressionEvaluationError(ParseError):
pass pass
@ -348,7 +352,9 @@ class Label(AstNode):
class Expression(AstNode): class Expression(AstNode):
# just a common base class for the nodes that are an expression themselves: # just a common base class for the nodes that are an expression themselves:
# ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference # ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference
def is_compile_constant(self) -> bool: is_lhs = attr.ib(type=bool, init=False, default=False) # left hand side of incrdecr/assignment/augassign?
def is_compiletime_const(self) -> bool:
raise NotImplementedError("implement in subclass") raise NotImplementedError("implement in subclass")
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
@ -382,11 +388,11 @@ class Register(Expression):
return NotImplemented return NotImplemented
return self.name < other.name return self.name < other.name
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
return False return False
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("register doesn't have a constant numeric value", self) raise NotCompiletimeConstantError("register doesn't have a constant numeric value", self)
@attr.s(cmp=False) @attr.s(cmp=False)
@ -464,7 +470,7 @@ class LiteralValue(Expression):
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
return self.value return self.value
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
return True return True
@ -473,7 +479,7 @@ class AddressOf(Expression):
# no subnodes. # no subnodes.
name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str)) name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str))
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable # address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \ return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \
@ -486,16 +492,16 @@ class AddressOf(Expression):
return symdef.zp_address return symdef.zp_address
if symdef.vartype == VarType.MEMORY: if symdef.vartype == VarType.MEMORY:
return symdef.value.const_value() return symdef.value.const_value()
raise TypeError("can only take constant address of a memory mapped variable", self) raise NotCompiletimeConstantError("can only take constant address of a memory mapped variable", self)
raise TypeError("should be a vardef to be able to take its address", self) raise NotCompiletimeConstantError("should be a vardef to be able to take its address", self)
@attr.s(cmp=False, slots=True) @attr.s(cmp=True, slots=True)
class SymbolName(Expression): class SymbolName(Expression):
# no subnodes. # no subnodes.
name = attr.ib(type=str) name = attr.ib(type=str)
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST
@ -503,7 +509,7 @@ class SymbolName(Expression):
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
return symdef.const_value() return symdef.const_value()
raise TypeError("should be a const vardef to be able to take its constant numeric value", self) raise NotCompiletimeConstantError("should be a const vardef to be able to take its constant numeric value", self)
@attr.s(cmp=False) @attr.s(cmp=False)
@ -531,11 +537,11 @@ class Dereference(Expression):
if self.nodes and not isinstance(self.nodes[0], (SymbolName, LiteralValue, Register)): if self.nodes and not isinstance(self.nodes[0], (SymbolName, LiteralValue, Register)):
raise TypeError("operand of dereference invalid type", self.nodes[0], self.sourceref) raise TypeError("operand of dereference invalid type", self.nodes[0], self.sourceref)
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
return False return False
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("dereference is not a constant numeric value") raise NotCompiletimeConstantError("dereference is not a constant numeric value")
@attr.s(cmp=False) @attr.s(cmp=False)
@ -556,6 +562,8 @@ class IncrDecr(AstNode):
raise ParseError("cannot incr/decr that register", self.sourceref) raise ParseError("cannot incr/decr that register", self.sourceref)
assert isinstance(target, (Register, SymbolName, Dereference)) assert isinstance(target, (Register, SymbolName, Dereference))
self.nodes.clear() self.nodes.clear()
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
target.is_lhs = True
self.nodes.append(target) self.nodes.append(target)
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -569,8 +577,6 @@ class IncrDecr(AstNode):
class ExpressionWithOperator(Expression): class ExpressionWithOperator(Expression):
# 2 nodes: left (Expression), right (not present if unary, Expression if not unary) # 2 nodes: left (Expression), right (not present if unary, Expression if not unary)
operator = attr.ib(type=str) operator = attr.ib(type=str)
# when evaluating the expression, does it have to be a compile-time constant value?
must_be_constant = attr.ib(type=bool, init=False, default=False)
@property @property
def unary(self) -> bool: def unary(self) -> bool:
@ -612,7 +618,7 @@ class ExpressionWithOperator(Expression):
elif self.operator == "not": elif self.operator == "not":
return not cv[0] return not cv[0]
elif self.operator == "&": elif self.operator == "&":
raise TypeError("the address-of operator should have been parsed into an AddressOf node") raise NotCompiletimeConstantError("the address-of operator should have been parsed into an AddressOf node")
else: else:
raise ValueError("invalid unary operator: "+self.operator, self.sourceref) raise ValueError("invalid unary operator: "+self.operator, self.sourceref)
else: else:
@ -664,11 +670,11 @@ class ExpressionWithOperator(Expression):
raise ValueError("invalid operator: "+self.operator, self.sourceref) raise ValueError("invalid operator: "+self.operator, self.sourceref)
@no_type_check @no_type_check
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
if len(self.nodes) == 1: if len(self.nodes) == 1:
return self.nodes[0].is_compile_constant() return self.nodes[0].is_compiletime_const()
elif len(self.nodes) == 2: elif len(self.nodes) == 2:
return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant() return self.nodes[0].is_compiletime_const() and self.nodes[1].is_compiletime_const()
raise ValueError("should have 1 or 2 nodes") raise ValueError("should have 1 or 2 nodes")
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue: def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
@ -742,7 +748,7 @@ class SubCall(Expression):
def arguments(self) -> CallArguments: def arguments(self) -> CallArguments:
return self.nodes[2] # type: ignore return self.nodes[2] # type: ignore
def is_compile_constant(self) -> bool: def is_compiletime_const(self) -> bool:
if isinstance(self.nodes[0], SymbolName): if isinstance(self.nodes[0], SymbolName):
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name) symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
if isinstance(symdef, BuiltinFunction): if isinstance(symdef, BuiltinFunction):
@ -756,7 +762,7 @@ class SubCall(Expression):
if isinstance(symdef, BuiltinFunction): if isinstance(symdef, BuiltinFunction):
arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes] arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes]
return symdef.func(*arguments) return symdef.func(*arguments)
raise TypeError("subroutine call is not a constant value", self) raise NotCompiletimeConstantError("subroutine call is not a constant value", self)
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
@ -779,13 +785,10 @@ class VarDef(AstNode):
self.nodes[0] = value self.nodes[0] = value
else: else:
self.nodes.append(value) self.nodes.append(value)
if isinstance(value, ExpressionWithOperator):
# an expression in a vardef should evaluate to a compile-time constant:
value.must_be_constant = True
def const_value(self) -> Union[int, float, bool, str]: def const_value(self) -> Union[int, float, bool, str]:
if self.vartype != VarType.CONST: if self.vartype != VarType.CONST:
raise TypeError("not a constant value", self) raise NotCompiletimeConstantError("not a constant value", self)
if self.nodes and isinstance(self.nodes[0], Expression): if self.nodes and isinstance(self.nodes[0], Expression):
return self.nodes[0].const_value() return self.nodes[0].const_value()
raise ValueError("no value", self) raise ValueError("no value", self)
@ -841,59 +844,14 @@ class Return(AstNode):
return self.nodes[2] if len(self.nodes) >= 3 else None # type: ignore return self.nodes[2] if len(self.nodes) >= 3 else None # type: ignore
@attr.s(cmp=False, slots=True, repr=False)
class AssignmentTargets(AstNode):
# a list of one or more assignment targets (Register, SymbolName, or Dereference).
nodes = attr.ib(type=list, init=True) # requires nodes in __init__
def has_memvalue(self) -> bool:
for t in self.nodes:
if isinstance(t, Dereference):
return True
if isinstance(t, SymbolName):
symdef = self.my_scope().lookup(t.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
return True
return False
def same_targets(self, other: 'AssignmentTargets') -> bool:
if len(self.nodes) != len(other.nodes):
return False
# @todo be able to compare targets in different order as well (sort them)
for t1, t2 in zip(self.nodes, other.nodes):
if type(t1) is not type(t2):
return False
if isinstance(t1, Register):
if t1 != t2: # __eq__ is defined
return False
elif isinstance(t1, SymbolName):
if t1.name != t2.name:
return False
elif isinstance(t1, Dereference):
if t1.size != t2.size or t1.datatype != t2.datatype:
return False
op1, op2 = t1.operand, t2.operand
if type(op1) is not type(op2):
return False
if isinstance(op1, SymbolName):
if op1.name != op2.name:
return False
else:
if op1 != op2:
return False
else:
return False
return True
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
class Assignment(AstNode): class Assignment(AstNode):
# can be single- or multi-assignment # can be single- or multi-assignment
# has two subnodes: left (=AssignmentTargets) and right (=Expression, # has two subnodes: left (=Register/SymbolName/Dereference ) and right (=Expression,
# or another Assignment but those will be converted into multi assign) # or another Assignment but those will be converted into multi assign)
@property @property
def left(self) -> AssignmentTargets: def left(self) -> Union[Register, SymbolName, Dereference]:
return self.nodes[0] # type: ignore return self.nodes[0] # type: ignore
@property @property
@ -905,6 +863,10 @@ class Assignment(AstNode):
assert isinstance(rvalue, Expression) assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue self.nodes[1] = rvalue
def mark_lhs(self):
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
self.nodes[0].is_lhs = True
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
class AugAssignment(AstNode): class AugAssignment(AstNode):
@ -924,6 +886,10 @@ class AugAssignment(AstNode):
assert isinstance(rvalue, Expression) assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue self.nodes[1] = rvalue
def mark_lhs(self):
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
self.nodes[0].is_lhs = True
def datatype_of(targetnode: AstNode, scope: Scope) -> DataType: def datatype_of(targetnode: AstNode, scope: Scope) -> DataType:
# tries to determine the DataType of an assignment target node # tries to determine the DataType of an assignment target node
@ -1512,8 +1478,9 @@ def p_assignment(p):
| assignment_target IS assignment | assignment_target IS assignment
""" """
p[0] = Assignment(sourceref=_token_sref(p, 2)) p[0] = Assignment(sourceref=_token_sref(p, 2))
p[0].nodes.append(AssignmentTargets(nodes=[p[1]], sourceref=p[0].sourceref)) p[0].nodes.append(p[1])
p[0].nodes.append(p[3]) p[0].nodes.append(p[3])
p[0].mark_lhs()
def p_aug_assignment(p): def p_aug_assignment(p):
@ -1523,6 +1490,7 @@ def p_aug_assignment(p):
p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2)) p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2))
p[0].nodes.append(p[1]) p[0].nodes.append(p[1])
p[0].nodes.append(p[3]) p[0].nodes.append(p[3])
p[0].mark_lhs()
precedence = ( precedence = (

View File

@ -164,16 +164,14 @@ def test_block_nodes():
assert len(sub2.scope.nodes) > 0 assert len(sub2.scope.nodes) > 0
test_source_2 = """ def test_parser_2():
src = """
~ { ~ {
999(1,2) 999(1,2)
[zz]() [zz]()
} }
""" """
result = parse_source(src)
def test_parser_2():
result = parse_source(test_source_2)
block = result.scope.nodes[0] block = result.scope.nodes[0]
call = block.scope.nodes[0] call = block.scope.nodes[0]
assert isinstance(call, SubCall) assert isinstance(call, SubCall)
@ -187,7 +185,8 @@ def test_parser_2():
assert call.target.operand.name == "zz" assert call.target.operand.name == "zz"
test_source_3 = """ def test_typespec():
src = """
~ { ~ {
[$c000.word] = 5 [$c000.word] = 5
[$c000 .byte] = 5 [$c000 .byte] = 5
@ -195,10 +194,7 @@ test_source_3 = """
[AX .float] = 5 [AX .float] = 5
} }
""" """
result = parse_source(src)
def test_typespec():
result = parse_source(test_source_3)
block = result.scope.nodes[0] block = result.scope.nodes[0]
assignment1, assignment2, assignment3, assignment4 = block.scope.nodes assignment1, assignment2, assignment3, assignment4 = block.scope.nodes
assert assignment1.right.value == 5 assert assignment1.right.value == 5
@ -209,10 +205,10 @@ def test_typespec():
assert len(assignment2.left.nodes) == 1 assert len(assignment2.left.nodes) == 1
assert len(assignment3.left.nodes) == 1 assert len(assignment3.left.nodes) == 1
assert len(assignment4.left.nodes) == 1 assert len(assignment4.left.nodes) == 1
t1 = assignment1.left.nodes[0] t1 = assignment1.left
t2 = assignment2.left.nodes[0] t2 = assignment2.left
t3 = assignment3.left.nodes[0] t3 = assignment3.left
t4 = assignment4.left.nodes[0] t4 = assignment4.left
assert isinstance(t1, Dereference) assert isinstance(t1, Dereference)
assert isinstance(t2, Dereference) assert isinstance(t2, Dereference)
assert isinstance(t3, Dereference) assert isinstance(t3, Dereference)
@ -235,7 +231,8 @@ def test_typespec():
assert t4.size is None assert t4.size is None
test_source_4 = """ def test_char_string():
src = """
~ { ~ {
var x1 = '@' var x1 = '@'
var x2 = 'π' var x2 = 'π'
@ -245,10 +242,7 @@ test_source_4 = """
A = 'abc' A = 'abc'
} }
""" """
result = parse_source(src)
def test_char_string():
result = parse_source(test_source_4)
block = result.scope.nodes[0] block = result.scope.nodes[0]
var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes
assert var1.value.value == '@' assert var1.value.value == '@'
@ -260,7 +254,8 @@ def test_char_string():
# note: the actual one-charactor-to-bytevalue conversion is done at the very latest, when issuing an assignment statement # note: the actual one-charactor-to-bytevalue conversion is done at the very latest, when issuing an assignment statement
test_source_5 = """ def test_boolean_int():
src = """
~ { ~ {
var x1 = true var x1 = true
var x2 = false var x2 = false
@ -268,10 +263,7 @@ test_source_5 = """
A = false A = false
} }
""" """
result = parse_source(src)
def test_boolean_int():
result = parse_source(test_source_5)
block = result.scope.nodes[0] block = result.scope.nodes[0]
var1, var2, assgn1, assgn2, = block.scope.nodes var1, var2, assgn1, assgn2, = block.scope.nodes
assert type(var1.value.value) is int and var1.value.value == 1 assert type(var1.value.value) is int and var1.value.value == 1
@ -381,7 +373,7 @@ def test_const_numeric_expressions():
result.scope.define_builtin_functions() result.scope.define_builtin_functions()
assignments = list(result.all_nodes(Assignment)) assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments] e = [a.nodes[1] for a in assignments]
assert all(x.is_compile_constant() for x in e) assert all(x.is_compiletime_const() for x in e)
assert e[0].const_value() == 15 # 1+2+3+4+5 assert e[0].const_value() == 15 # 1+2+3+4+5
assert e[1].const_value() == 13 # 1+2*5+2 assert e[1].const_value() == 13 # 1+2*5+2
assert e[2].const_value() == 21 # (1+2)*(5+2) assert e[2].const_value() == 21 # (1+2)*(5+2)
@ -415,7 +407,7 @@ def test_const_logic_expressions():
result = parse_source(src) result = parse_source(src)
assignments = list(result.all_nodes(Assignment)) assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments] e = [a.nodes[1] for a in assignments]
assert all(x.is_compile_constant() for x in e) assert all(x.is_compiletime_const() for x in e)
assert e[0].const_value() == True assert e[0].const_value() == True
assert e[1].const_value() == False assert e[1].const_value() == False
assert e[2].const_value() == True assert e[2].const_value() == True
@ -441,12 +433,12 @@ def test_const_other_expressions():
result.scope.define_builtin_functions() result.scope.define_builtin_functions()
assignments = list(result.all_nodes(Assignment)) assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments] e = [a.nodes[1] for a in assignments]
assert e[0].is_compile_constant() assert e[0].is_compiletime_const()
assert e[0].const_value() == 0xc123 assert e[0].const_value() == 0xc123
assert not e[1].is_compile_constant() assert not e[1].is_compiletime_const()
with pytest.raises(TypeError): with pytest.raises(TypeError):
e[1].const_value() e[1].const_value()
assert not e[2].is_compile_constant() assert not e[2].is_compiletime_const()
with pytest.raises(TypeError): with pytest.raises(TypeError):
e[2].const_value() e[2].const_value()
@ -495,3 +487,67 @@ def test_vdef_const_folds():
assert vd[2].datatype == DataType.BYTE assert vd[2].datatype == DataType.BYTE
assert isinstance(vd[2].value, LiteralValue) assert isinstance(vd[2].value, LiteralValue)
assert vd[2].value.value == 369 assert vd[2].value.value == 369
def test_vdef_const_expressions():
src = """
~ {
var bvar = 99
var .float fvar = sin(1.2-0.3)
var .float flt2 = -9.87e-6
bvar ++
fvar ++
flt2 ++
bvar += 2+2
fvar += 2+3
flt2 += 2+4
bvar = 2+5
fvar = 2+6
flt2 = 2+7
}
"""
result = parse_source(src)
if isinstance(result, Module):
result.scope.define_builtin_functions()
cf = ConstantFold(result)
cf.fold_constants()
vd = list(result.all_nodes(VarDef))
assert len(vd)==3
assert vd[0].name == "bvar"
assert isinstance(vd[0].value, LiteralValue)
assert vd[0].value.value == 99
assert vd[1].name == "fvar"
assert isinstance(vd[1].value, LiteralValue)
assert type(vd[1].value.value) is float
assert math.isclose(vd[1].value.value, math.sin(0.9))
assert vd[2].name == "flt2"
assert isinstance(vd[2].value, LiteralValue)
assert math.isclose(-9.87e-6, vd[2].value.value)
# test incrdecr assignment target
nodes = list(result.all_nodes(IncrDecr))
assert len(nodes) == 3
assert isinstance(nodes[0].target, SymbolName)
assert nodes[0].target.name == "bvar"
assert isinstance(nodes[1].target, SymbolName)
assert nodes[1].target.name == "fvar"
assert isinstance(nodes[2].target, SymbolName)
assert nodes[2].target.name == "flt2"
# test augassign assignment target
nodes = list(result.all_nodes(AugAssignment))
assert len(nodes) == 3
assert isinstance(nodes[0].left, SymbolName)
assert nodes[0].left.name == "bvar"
assert isinstance(nodes[1].left, SymbolName)
assert nodes[1].left.name == "fvar"
assert isinstance(nodes[2].left, SymbolName)
assert nodes[2].left.name == "flt2"
# test assign assignment target
nodes = list(result.all_nodes(Assignment))
assert len(nodes) == 3
assert isinstance(nodes[0].left, SymbolName)
assert nodes[0].left.name == "bvar"
assert isinstance(nodes[1].left, SymbolName)
assert nodes[1].left.name == "fvar"
assert isinstance(nodes[2].left, SymbolName)
assert nodes[2].left.name == "flt2"

View File

@ -3,13 +3,6 @@ from il65.datatypes import DataType
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError
from il65.plylex import SourceRef from il65.plylex import SourceRef
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
# name = attr.ib(type=str)
# vartype = attr.ib()
# datatype = attr.ib()
# size = attr.ib(type=list, default=None)
# zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
def test_creation(): def test_creation():
sref = SourceRef("test", 1, 1) sref = SourceRef("test", 1, 1)
@ -62,10 +55,8 @@ def test_set_value():
assert v.value.value == "hello" assert v.value.value == "hello"
e = ExpressionWithOperator(operator="-", sourceref=sref) e = ExpressionWithOperator(operator="-", sourceref=sref)
e.left = LiteralValue(value=42, sourceref=sref) e.left = LiteralValue(value=42, sourceref=sref)
assert not e.must_be_constant
v.value = e v.value = e
assert v.value is e assert v.value is e
assert e.must_be_constant
def test_const_value(): def test_const_value():

View File

@ -1,4 +1,5 @@
~ main { ~ main {
var .float flt1 = 9.87e-21
var .float flt = -9.87e-21 var .float flt = -9.87e-21
const .word border = $0099 const .word border = $0099
var counter = 1 var counter = 1