const folding cleanups and explicit notion of assignment LHS

This commit is contained in:
Irmen de Jong 2018-02-18 18:19:51 +01:00
parent a171bb998d
commit de3bca0763
10 changed files with 234 additions and 204 deletions

View File

@ -93,21 +93,21 @@ class PlyParser:
# note: not processing regular assignments, because they can contain multiple targets of different datatype.
# this has to be dealt with anyway later, so we don't bother dealing with it here for just a special case.
if isinstance(node, AugAssignment):
if node.right.is_compile_constant():
if node.right.is_compiletime_const():
_, node.right = coerce_constant_value(datatype_of(node.left, node.my_scope()), node.right, node.right.sourceref)
elif isinstance(node, Goto):
if node.condition is not None and node.condition.is_compile_constant():
if node.condition is not None and node.condition.is_compiletime_const():
_, node.nodes[1] = coerce_constant_value(DataType.WORD, node.nodes[1], node.nodes[1].sourceref) # type: ignore
elif isinstance(node, Return):
if node.value_A is not None and node.value_A.is_compile_constant():
if node.value_A is not None and node.value_A.is_compiletime_const():
_, node.nodes[0] = coerce_constant_value(DataType.BYTE, node.nodes[0], node.nodes[0].sourceref) # type: ignore
if node.value_X is not None and node.value_X.is_compile_constant():
if node.value_X is not None and node.value_X.is_compiletime_const():
_, node.nodes[1] = coerce_constant_value(DataType.BYTE, node.nodes[1], node.nodes[1].sourceref) # type: ignore
if node.value_Y is not None and node.value_Y.is_compile_constant():
if node.value_Y is not None and node.value_Y.is_compiletime_const():
_, node.nodes[2] = coerce_constant_value(DataType.BYTE, node.nodes[2], node.nodes[2].sourceref) # type: ignore
elif isinstance(node, VarDef):
if node.value is not None:
if node.value.is_compile_constant():
if node.value.is_compiletime_const():
_, node.value = coerce_constant_value(datatype_of(node, node.my_scope()), node.value, node.value.sourceref)
except OverflowError as x:
raise ParseError(str(x), node.sourceref)
@ -189,7 +189,7 @@ class PlyParser:
if isinstance(node.right, LiteralValue) and node.right.value == 0:
raise ParseError("division by zero", node.right.sourceref)
elif isinstance(node, VarDef):
if node.value is not None and not node.value.is_compile_constant():
if node.value is not None and not node.value.is_compiletime_const():
raise ParseError("variable initialization value should be a compile-time constant", node.value.sourceref)
elif isinstance(node, Dereference):
if isinstance(node.operand, Register) and node.operand.datatype == DataType.BYTE:

View File

@ -62,8 +62,12 @@ class ConstantFold:
def _process_expression(self, expr: Expression) -> Expression:
# process/simplify all expressions (constant folding etc)
if expr.is_lhs:
if isinstance(expr, (Register, SymbolName, Dereference)):
return expr
raise ParseError("invalid lhs expression type", expr.sourceref)
result = None # type: Expression
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
if expr.is_compiletime_const():
result = self._process_constant_expression(expr, expr.sourceref)
else:
result = self._process_dynamic_expression(expr, expr.sourceref)
@ -74,9 +78,11 @@ class ConstantFold:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
try:
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
except NotCompiletimeConstantError:
pass
if isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
@ -169,26 +175,24 @@ class ConstantFold:
# constant-fold a dynamic expression
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
try:
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
except NotCompiletimeConstantError:
pass
if isinstance(expr, SymbolName):
try:
return self._process_constant_expression(expr, sourceref)
except (ExpressionEvaluationError, NotCompiletimeConstantError):
return expr
elif isinstance(expr, AddressOf):
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
try:
return self._process_constant_expression(expr, sourceref)
except (ExpressionEvaluationError, NotCompiletimeConstantError):
return expr
elif isinstance(expr, SubCall):
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
except (ExpressionEvaluationError, NotCompiletimeConstantError):
if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr
@ -199,12 +203,10 @@ class ConstantFold:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
expr.left.parent = expr
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
try:
return self._process_constant_expression(expr, sourceref)
except (ExpressionEvaluationError, NotCompiletimeConstantError):
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
@ -212,11 +214,9 @@ class ConstantFold:
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = self._process_dynamic_expression(expr.right, right_sourceref)
expr.right.parent = expr
if expr.is_compile_constant():
try:
return self._process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
pass
return expr
try:
return self._process_constant_expression(expr, sourceref)
except (ExpressionEvaluationError, NotCompiletimeConstantError):
return expr
else:
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)

View File

@ -16,7 +16,7 @@ def generate_assignment(ctx: Context) -> None:
assert isinstance(ctx.stmt, Assignment)
assert not isinstance(ctx.stmt.right, Assignment), "assignment should have been flattened"
ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left.nodes, ctx.stmt.right))
ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left, ctx.stmt.right))
# @todo assignment

View File

@ -15,9 +15,9 @@ def generate_goto(ctx: Context) -> None:
ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
if stmt.condition:
if stmt.if_stmt:
_gen_goto_special_if_cond(ctx, stmt)
_gen_goto_cond(ctx, stmt, "true")
else:
_gen_goto_cond(ctx, stmt)
_gen_goto_cond(ctx, stmt, stmt.if_cond)
else:
if stmt.if_stmt:
_gen_goto_special_if(ctx, stmt)
@ -134,12 +134,11 @@ def _gen_goto_unconditional(ctx: Context, stmt: Goto) -> None:
raise CodeError("invalid goto target type", stmt)
def _gen_goto_special_if_cond(ctx: Context, stmt: Goto) -> None:
pass # @todo special if WITH conditional expression
def _gen_goto_cond(ctx: Context, stmt: Goto) -> None:
pass # @todo regular if WITH conditional expression
def _gen_goto_cond(ctx: Context, stmt: Goto, if_cond: str) -> None:
if isinstance(stmt.condition, LiteralValue):
pass # @todo if WITH conditional expression
else:
raise CodeError("no support for evaluating conditional expression yet", stmt) # @todo
def generate_subcall(ctx: Context) -> None:

View File

@ -10,7 +10,7 @@ import datetime
from typing import TextIO, Callable, no_type_check
from ..plylex import print_bold
from ..plyparse import (Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, ZpOptions,
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr, AssignmentTargets)
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr)
from . import CodeError, to_hex, to_mflpt5, Context
from .variables import generate_block_init, generate_block_vars
from .assignment import generate_assignment, generate_aug_assignment
@ -195,22 +195,25 @@ class AssemblyGenerator:
if stmt.value_A:
reg = Register(name="A", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_A)
assignment.mark_lhs()
ctx.stmt = assignment
generate_assignment(ctx)
if stmt.value_X:
reg = Register(name="X", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_X)
assignment.mark_lhs()
ctx.stmt = assignment
generate_assignment(ctx)
if stmt.value_Y:
reg = Register(name="Y", sourceref=stmt.sourceref)
assignment = Assignment(sourceref=stmt.sourceref)
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
assignment.nodes.append(reg)
assignment.nodes.append(stmt.value_Y)
assignment.mark_lhs()
ctx.stmt = assignment
generate_assignment(ctx)
ctx.out("\vrts")

View File

@ -7,7 +7,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
"""
from typing import List, no_type_check, Union
from .datatypes import DataType
from .datatypes import DataType, VarType
from .plyparse import *
from .plylex import print_warning, print_bold
from .constantfold import ConstantFold
@ -117,8 +117,13 @@ class Optimizer:
return True
if isinstance(node1, SymbolName) and isinstance(node2, SymbolName) and node1.name == node2.name:
return True
if isinstance(node1, Dereference) and isinstance(node2, Dereference) and node1.operand == node2.operand:
return True
if isinstance(node1, Dereference) and isinstance(node2, Dereference):
if type(node1.operand) is not type(node2.operand):
return False
if isinstance(node1.operand, (SymbolName, LiteralValue, Register)):
return node1.operand == node2.operand
if not isinstance(node1, AstNode) or not isinstance(node2, AstNode):
raise TypeError("same_target called with invalid type(s)", node1, node2)
return False
@no_type_check
@ -132,7 +137,7 @@ class Optimizer:
continue
expr = assignment.right
if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
num_val = expr.right.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
@ -141,13 +146,13 @@ class Optimizer:
continue
if expr.operator not in ('+', '*', '|', '^'): # associative operators
continue
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
num_val = expr.right.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right):
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.right):
num_val = expr.left.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
@ -161,12 +166,16 @@ class Optimizer:
prev_node = None # type: AstNode
for node in list(scope.nodes):
if isinstance(node, Assignment) and isinstance(prev_node, Assignment):
if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left):
if not node.left.has_memvalue():
scope.remove_node(prev_node)
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: removed superfluous assignment".format(prev_node.sourceref))
if isinstance(node.right, (LiteralValue, Register)) and self._same_target(node.left, prev_node.left):
if isinstance(node.left, SymbolName):
# only optimize if the symbol is not a memory mapped address (volatile memory!)
symdef = node.left.my_scope().lookup(node.left.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
continue
scope.remove_node(prev_node)
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: removed superfluous assignment".format(prev_node.sourceref))
prev_node = node
@no_type_check
@ -178,17 +187,17 @@ class Optimizer:
# @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?)
for assignment in self.module.all_nodes():
if isinstance(assignment, Assignment):
if all(lv == assignment.right for lv in assignment.left.nodes):
if self._same_target(assignment.left, assignment.right):
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
print_warning("{}: removed statement that has no effect (left=right)".format(assignment.sourceref))
elif isinstance(assignment, AugAssignment):
if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)):
if assignment.right.value == 0:
if assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="):
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
print_warning("{}: removed statement that has no effect (aug.assign zero)".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
elif assignment.operator == "*=":
@ -206,9 +215,10 @@ class Optimizer:
elif assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="):
print("{}: shifting result is always zero".format(assignment.sourceref))
new_stmt = Assignment(sourceref=assignment.sourceref)
new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref))
new_stmt.nodes.append(assignment.left)
new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref))
assignment.my_scope().replace_node(assignment, new_stmt)
assignment.mark_lhs()
self.optimizations_performed = True
elif assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256:
howmuch = assignment.right
@ -223,7 +233,7 @@ class Optimizer:
self.optimizations_performed = True
elif assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="):
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
print_warning("{}: removed statement that has no effect (aug.assign identity)".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
@ -231,12 +241,13 @@ class Optimizer:
def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment:
new_assignment = Assignment(sourceref=old_aug_assignment.sourceref)
new_assignment.parent = old_aug_assignment.parent
left = AssignmentTargets(nodes=[old_aug_assignment.left], sourceref=old_aug_assignment.sourceref)
left = old_aug_assignment.left
left.parent = new_assignment
new_assignment.nodes.append(left)
value = LiteralValue(value=constantvalue, sourceref=old_aug_assignment.sourceref)
value.parent = new_assignment
new_assignment.nodes.append(value)
new_assignment.mark_lhs()
return new_assignment
@no_type_check
@ -250,6 +261,7 @@ class Optimizer:
a.nodes.append(lv)
lv.parent = a
a.parent = old_assign.parent
a.mark_lhs()
return a
@no_type_check

View File

@ -22,8 +22,8 @@ __all__ = ["ProgramFormat", "ZpOptions", "math_functions", "builtin_functions",
"UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression",
"Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr",
"ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment",
"InlineAssembly", "AssignmentTargets", "BuiltinFunction", "TokenFilter", "parser", "connect_parents",
"parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition"]
"InlineAssembly", "BuiltinFunction", "TokenFilter", "parser", "connect_parents",
"parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition", "NotCompiletimeConstantError"]
class ProgramFormat(enum.Enum):
@ -55,6 +55,10 @@ class ParseError(Exception):
return "{} {:s}".format(self.sourceref, self.args[0])
class NotCompiletimeConstantError(TypeError):
pass
class ExpressionEvaluationError(ParseError):
pass
@ -348,7 +352,9 @@ class Label(AstNode):
class Expression(AstNode):
# just a common base class for the nodes that are an expression themselves:
# ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference
def is_compile_constant(self) -> bool:
is_lhs = attr.ib(type=bool, init=False, default=False) # left hand side of incrdecr/assignment/augassign?
def is_compiletime_const(self) -> bool:
raise NotImplementedError("implement in subclass")
def const_value(self) -> Union[int, float, bool, str]:
@ -382,11 +388,11 @@ class Register(Expression):
return NotImplemented
return self.name < other.name
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("register doesn't have a constant numeric value", self)
raise NotCompiletimeConstantError("register doesn't have a constant numeric value", self)
@attr.s(cmp=False)
@ -464,7 +470,7 @@ class LiteralValue(Expression):
def const_value(self) -> Union[int, float, bool, str]:
return self.value
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
return True
@ -473,7 +479,7 @@ class AddressOf(Expression):
# no subnodes.
name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str))
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \
@ -486,16 +492,16 @@ class AddressOf(Expression):
return symdef.zp_address
if symdef.vartype == VarType.MEMORY:
return symdef.value.const_value()
raise TypeError("can only take constant address of a memory mapped variable", self)
raise TypeError("should be a vardef to be able to take its address", self)
raise NotCompiletimeConstantError("can only take constant address of a memory mapped variable", self)
raise NotCompiletimeConstantError("should be a vardef to be able to take its address", self)
@attr.s(cmp=False, slots=True)
@attr.s(cmp=True, slots=True)
class SymbolName(Expression):
# no subnodes.
name = attr.ib(type=str)
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST
@ -503,7 +509,7 @@ class SymbolName(Expression):
symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
return symdef.const_value()
raise TypeError("should be a const vardef to be able to take its constant numeric value", self)
raise NotCompiletimeConstantError("should be a const vardef to be able to take its constant numeric value", self)
@attr.s(cmp=False)
@ -531,11 +537,11 @@ class Dereference(Expression):
if self.nodes and not isinstance(self.nodes[0], (SymbolName, LiteralValue, Register)):
raise TypeError("operand of dereference invalid type", self.nodes[0], self.sourceref)
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("dereference is not a constant numeric value")
raise NotCompiletimeConstantError("dereference is not a constant numeric value")
@attr.s(cmp=False)
@ -556,6 +562,8 @@ class IncrDecr(AstNode):
raise ParseError("cannot incr/decr that register", self.sourceref)
assert isinstance(target, (Register, SymbolName, Dereference))
self.nodes.clear()
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
target.is_lhs = True
self.nodes.append(target)
def __attrs_post_init__(self):
@ -569,8 +577,6 @@ class IncrDecr(AstNode):
class ExpressionWithOperator(Expression):
# 2 nodes: left (Expression), right (not present if unary, Expression if not unary)
operator = attr.ib(type=str)
# when evaluating the expression, does it have to be a compile-time constant value?
must_be_constant = attr.ib(type=bool, init=False, default=False)
@property
def unary(self) -> bool:
@ -612,7 +618,7 @@ class ExpressionWithOperator(Expression):
elif self.operator == "not":
return not cv[0]
elif self.operator == "&":
raise TypeError("the address-of operator should have been parsed into an AddressOf node")
raise NotCompiletimeConstantError("the address-of operator should have been parsed into an AddressOf node")
else:
raise ValueError("invalid unary operator: "+self.operator, self.sourceref)
else:
@ -664,11 +670,11 @@ class ExpressionWithOperator(Expression):
raise ValueError("invalid operator: "+self.operator, self.sourceref)
@no_type_check
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
if len(self.nodes) == 1:
return self.nodes[0].is_compile_constant()
return self.nodes[0].is_compiletime_const()
elif len(self.nodes) == 2:
return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant()
return self.nodes[0].is_compiletime_const() and self.nodes[1].is_compiletime_const()
raise ValueError("should have 1 or 2 nodes")
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
@ -742,7 +748,7 @@ class SubCall(Expression):
def arguments(self) -> CallArguments:
return self.nodes[2] # type: ignore
def is_compile_constant(self) -> bool:
def is_compiletime_const(self) -> bool:
if isinstance(self.nodes[0], SymbolName):
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
if isinstance(symdef, BuiltinFunction):
@ -756,7 +762,7 @@ class SubCall(Expression):
if isinstance(symdef, BuiltinFunction):
arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes]
return symdef.func(*arguments)
raise TypeError("subroutine call is not a constant value", self)
raise NotCompiletimeConstantError("subroutine call is not a constant value", self)
@attr.s(cmp=False, slots=True, repr=False)
@ -779,13 +785,10 @@ class VarDef(AstNode):
self.nodes[0] = value
else:
self.nodes.append(value)
if isinstance(value, ExpressionWithOperator):
# an expression in a vardef should evaluate to a compile-time constant:
value.must_be_constant = True
def const_value(self) -> Union[int, float, bool, str]:
if self.vartype != VarType.CONST:
raise TypeError("not a constant value", self)
raise NotCompiletimeConstantError("not a constant value", self)
if self.nodes and isinstance(self.nodes[0], Expression):
return self.nodes[0].const_value()
raise ValueError("no value", self)
@ -841,59 +844,14 @@ class Return(AstNode):
return self.nodes[2] if len(self.nodes) >= 3 else None # type: ignore
@attr.s(cmp=False, slots=True, repr=False)
class AssignmentTargets(AstNode):
# a list of one or more assignment targets (Register, SymbolName, or Dereference).
nodes = attr.ib(type=list, init=True) # requires nodes in __init__
def has_memvalue(self) -> bool:
for t in self.nodes:
if isinstance(t, Dereference):
return True
if isinstance(t, SymbolName):
symdef = self.my_scope().lookup(t.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
return True
return False
def same_targets(self, other: 'AssignmentTargets') -> bool:
if len(self.nodes) != len(other.nodes):
return False
# @todo be able to compare targets in different order as well (sort them)
for t1, t2 in zip(self.nodes, other.nodes):
if type(t1) is not type(t2):
return False
if isinstance(t1, Register):
if t1 != t2: # __eq__ is defined
return False
elif isinstance(t1, SymbolName):
if t1.name != t2.name:
return False
elif isinstance(t1, Dereference):
if t1.size != t2.size or t1.datatype != t2.datatype:
return False
op1, op2 = t1.operand, t2.operand
if type(op1) is not type(op2):
return False
if isinstance(op1, SymbolName):
if op1.name != op2.name:
return False
else:
if op1 != op2:
return False
else:
return False
return True
@attr.s(cmp=False, slots=True, repr=False)
class Assignment(AstNode):
# can be single- or multi-assignment
# has two subnodes: left (=AssignmentTargets) and right (=Expression,
# has two subnodes: left (=Register/SymbolName/Dereference ) and right (=Expression,
# or another Assignment but those will be converted into multi assign)
@property
def left(self) -> AssignmentTargets:
def left(self) -> Union[Register, SymbolName, Dereference]:
return self.nodes[0] # type: ignore
@property
@ -905,6 +863,10 @@ class Assignment(AstNode):
assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue
def mark_lhs(self):
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
self.nodes[0].is_lhs = True
@attr.s(cmp=False, slots=True, repr=False)
class AugAssignment(AstNode):
@ -924,6 +886,10 @@ class AugAssignment(AstNode):
assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue
def mark_lhs(self):
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
self.nodes[0].is_lhs = True
def datatype_of(targetnode: AstNode, scope: Scope) -> DataType:
# tries to determine the DataType of an assignment target node
@ -1512,8 +1478,9 @@ def p_assignment(p):
| assignment_target IS assignment
"""
p[0] = Assignment(sourceref=_token_sref(p, 2))
p[0].nodes.append(AssignmentTargets(nodes=[p[1]], sourceref=p[0].sourceref))
p[0].nodes.append(p[1])
p[0].nodes.append(p[3])
p[0].mark_lhs()
def p_aug_assignment(p):
@ -1523,6 +1490,7 @@ def p_aug_assignment(p):
p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2))
p[0].nodes.append(p[1])
p[0].nodes.append(p[3])
p[0].mark_lhs()
precedence = (

View File

@ -164,16 +164,14 @@ def test_block_nodes():
assert len(sub2.scope.nodes) > 0
test_source_2 = """
~ {
999(1,2)
[zz]()
}
"""
def test_parser_2():
result = parse_source(test_source_2)
src = """
~ {
999(1,2)
[zz]()
}
"""
result = parse_source(src)
block = result.scope.nodes[0]
call = block.scope.nodes[0]
assert isinstance(call, SubCall)
@ -187,18 +185,16 @@ def test_parser_2():
assert call.target.operand.name == "zz"
test_source_3 = """
~ {
[$c000.word] = 5
[$c000 .byte] = 5
[AX .word] = 5
[AX .float] = 5
}
"""
def test_typespec():
result = parse_source(test_source_3)
src = """
~ {
[$c000.word] = 5
[$c000 .byte] = 5
[AX .word] = 5
[AX .float] = 5
}
"""
result = parse_source(src)
block = result.scope.nodes[0]
assignment1, assignment2, assignment3, assignment4 = block.scope.nodes
assert assignment1.right.value == 5
@ -209,10 +205,10 @@ def test_typespec():
assert len(assignment2.left.nodes) == 1
assert len(assignment3.left.nodes) == 1
assert len(assignment4.left.nodes) == 1
t1 = assignment1.left.nodes[0]
t2 = assignment2.left.nodes[0]
t3 = assignment3.left.nodes[0]
t4 = assignment4.left.nodes[0]
t1 = assignment1.left
t2 = assignment2.left
t3 = assignment3.left
t4 = assignment4.left
assert isinstance(t1, Dereference)
assert isinstance(t2, Dereference)
assert isinstance(t3, Dereference)
@ -235,20 +231,18 @@ def test_typespec():
assert t4.size is None
test_source_4 = """
~ {
var x1 = '@'
var x2 = 'π'
var x3 = 'abc'
A = '@'
A = 'π'
A = 'abc'
}
"""
def test_char_string():
result = parse_source(test_source_4)
src = """
~ {
var x1 = '@'
var x2 = 'π'
var x3 = 'abc'
A = '@'
A = 'π'
A = 'abc'
}
"""
result = parse_source(src)
block = result.scope.nodes[0]
var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes
assert var1.value.value == '@'
@ -260,18 +254,16 @@ def test_char_string():
# note: the actual one-charactor-to-bytevalue conversion is done at the very latest, when issuing an assignment statement
test_source_5 = """
~ {
var x1 = true
var x2 = false
A = true
A = false
}
"""
def test_boolean_int():
result = parse_source(test_source_5)
src = """
~ {
var x1 = true
var x2 = false
A = true
A = false
}
"""
result = parse_source(src)
block = result.scope.nodes[0]
var1, var2, assgn1, assgn2, = block.scope.nodes
assert type(var1.value.value) is int and var1.value.value == 1
@ -381,7 +373,7 @@ def test_const_numeric_expressions():
result.scope.define_builtin_functions()
assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments]
assert all(x.is_compile_constant() for x in e)
assert all(x.is_compiletime_const() for x in e)
assert e[0].const_value() == 15 # 1+2+3+4+5
assert e[1].const_value() == 13 # 1+2*5+2
assert e[2].const_value() == 21 # (1+2)*(5+2)
@ -415,7 +407,7 @@ def test_const_logic_expressions():
result = parse_source(src)
assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments]
assert all(x.is_compile_constant() for x in e)
assert all(x.is_compiletime_const() for x in e)
assert e[0].const_value() == True
assert e[1].const_value() == False
assert e[2].const_value() == True
@ -441,12 +433,12 @@ def test_const_other_expressions():
result.scope.define_builtin_functions()
assignments = list(result.all_nodes(Assignment))
e = [a.nodes[1] for a in assignments]
assert e[0].is_compile_constant()
assert e[0].is_compiletime_const()
assert e[0].const_value() == 0xc123
assert not e[1].is_compile_constant()
assert not e[1].is_compiletime_const()
with pytest.raises(TypeError):
e[1].const_value()
assert not e[2].is_compile_constant()
assert not e[2].is_compiletime_const()
with pytest.raises(TypeError):
e[2].const_value()
@ -495,3 +487,67 @@ def test_vdef_const_folds():
assert vd[2].datatype == DataType.BYTE
assert isinstance(vd[2].value, LiteralValue)
assert vd[2].value.value == 369
def test_vdef_const_expressions():
src = """
~ {
var bvar = 99
var .float fvar = sin(1.2-0.3)
var .float flt2 = -9.87e-6
bvar ++
fvar ++
flt2 ++
bvar += 2+2
fvar += 2+3
flt2 += 2+4
bvar = 2+5
fvar = 2+6
flt2 = 2+7
}
"""
result = parse_source(src)
if isinstance(result, Module):
result.scope.define_builtin_functions()
cf = ConstantFold(result)
cf.fold_constants()
vd = list(result.all_nodes(VarDef))
assert len(vd)==3
assert vd[0].name == "bvar"
assert isinstance(vd[0].value, LiteralValue)
assert vd[0].value.value == 99
assert vd[1].name == "fvar"
assert isinstance(vd[1].value, LiteralValue)
assert type(vd[1].value.value) is float
assert math.isclose(vd[1].value.value, math.sin(0.9))
assert vd[2].name == "flt2"
assert isinstance(vd[2].value, LiteralValue)
assert math.isclose(-9.87e-6, vd[2].value.value)
# test incrdecr assignment target
nodes = list(result.all_nodes(IncrDecr))
assert len(nodes) == 3
assert isinstance(nodes[0].target, SymbolName)
assert nodes[0].target.name == "bvar"
assert isinstance(nodes[1].target, SymbolName)
assert nodes[1].target.name == "fvar"
assert isinstance(nodes[2].target, SymbolName)
assert nodes[2].target.name == "flt2"
# test augassign assignment target
nodes = list(result.all_nodes(AugAssignment))
assert len(nodes) == 3
assert isinstance(nodes[0].left, SymbolName)
assert nodes[0].left.name == "bvar"
assert isinstance(nodes[1].left, SymbolName)
assert nodes[1].left.name == "fvar"
assert isinstance(nodes[2].left, SymbolName)
assert nodes[2].left.name == "flt2"
# test assign assignment target
nodes = list(result.all_nodes(Assignment))
assert len(nodes) == 3
assert isinstance(nodes[0].left, SymbolName)
assert nodes[0].left.name == "bvar"
assert isinstance(nodes[1].left, SymbolName)
assert nodes[1].left.name == "fvar"
assert isinstance(nodes[2].left, SymbolName)
assert nodes[2].left.name == "flt2"

View File

@ -3,13 +3,6 @@ from il65.datatypes import DataType
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError
from il65.plylex import SourceRef
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
# name = attr.ib(type=str)
# vartype = attr.ib()
# datatype = attr.ib()
# size = attr.ib(type=list, default=None)
# zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
def test_creation():
sref = SourceRef("test", 1, 1)
@ -62,10 +55,8 @@ def test_set_value():
assert v.value.value == "hello"
e = ExpressionWithOperator(operator="-", sourceref=sref)
e.left = LiteralValue(value=42, sourceref=sref)
assert not e.must_be_constant
v.value = e
assert v.value is e
assert e.must_be_constant
def test_const_value():

View File

@ -1,4 +1,5 @@
~ main {
var .float flt1 = 9.87e-21
var .float flt = -9.87e-21
const .word border = $0099
var counter = 1