mirror of
https://github.com/irmen/prog8.git
synced 2025-01-11 13:29:45 +00:00
const folding cleanups and explicit notion of assignment LHS
This commit is contained in:
parent
a171bb998d
commit
de3bca0763
@ -93,21 +93,21 @@ class PlyParser:
|
||||
# note: not processing regular assignments, because they can contain multiple targets of different datatype.
|
||||
# this has to be dealt with anyway later, so we don't bother dealing with it here for just a special case.
|
||||
if isinstance(node, AugAssignment):
|
||||
if node.right.is_compile_constant():
|
||||
if node.right.is_compiletime_const():
|
||||
_, node.right = coerce_constant_value(datatype_of(node.left, node.my_scope()), node.right, node.right.sourceref)
|
||||
elif isinstance(node, Goto):
|
||||
if node.condition is not None and node.condition.is_compile_constant():
|
||||
if node.condition is not None and node.condition.is_compiletime_const():
|
||||
_, node.nodes[1] = coerce_constant_value(DataType.WORD, node.nodes[1], node.nodes[1].sourceref) # type: ignore
|
||||
elif isinstance(node, Return):
|
||||
if node.value_A is not None and node.value_A.is_compile_constant():
|
||||
if node.value_A is not None and node.value_A.is_compiletime_const():
|
||||
_, node.nodes[0] = coerce_constant_value(DataType.BYTE, node.nodes[0], node.nodes[0].sourceref) # type: ignore
|
||||
if node.value_X is not None and node.value_X.is_compile_constant():
|
||||
if node.value_X is not None and node.value_X.is_compiletime_const():
|
||||
_, node.nodes[1] = coerce_constant_value(DataType.BYTE, node.nodes[1], node.nodes[1].sourceref) # type: ignore
|
||||
if node.value_Y is not None and node.value_Y.is_compile_constant():
|
||||
if node.value_Y is not None and node.value_Y.is_compiletime_const():
|
||||
_, node.nodes[2] = coerce_constant_value(DataType.BYTE, node.nodes[2], node.nodes[2].sourceref) # type: ignore
|
||||
elif isinstance(node, VarDef):
|
||||
if node.value is not None:
|
||||
if node.value.is_compile_constant():
|
||||
if node.value.is_compiletime_const():
|
||||
_, node.value = coerce_constant_value(datatype_of(node, node.my_scope()), node.value, node.value.sourceref)
|
||||
except OverflowError as x:
|
||||
raise ParseError(str(x), node.sourceref)
|
||||
@ -189,7 +189,7 @@ class PlyParser:
|
||||
if isinstance(node.right, LiteralValue) and node.right.value == 0:
|
||||
raise ParseError("division by zero", node.right.sourceref)
|
||||
elif isinstance(node, VarDef):
|
||||
if node.value is not None and not node.value.is_compile_constant():
|
||||
if node.value is not None and not node.value.is_compiletime_const():
|
||||
raise ParseError("variable initialization value should be a compile-time constant", node.value.sourceref)
|
||||
elif isinstance(node, Dereference):
|
||||
if isinstance(node.operand, Register) and node.operand.datatype == DataType.BYTE:
|
||||
|
@ -62,8 +62,12 @@ class ConstantFold:
|
||||
|
||||
def _process_expression(self, expr: Expression) -> Expression:
|
||||
# process/simplify all expressions (constant folding etc)
|
||||
if expr.is_lhs:
|
||||
if isinstance(expr, (Register, SymbolName, Dereference)):
|
||||
return expr
|
||||
raise ParseError("invalid lhs expression type", expr.sourceref)
|
||||
result = None # type: Expression
|
||||
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
|
||||
if expr.is_compiletime_const():
|
||||
result = self._process_constant_expression(expr, expr.sourceref)
|
||||
else:
|
||||
result = self._process_dynamic_expression(expr, expr.sourceref)
|
||||
@ -74,9 +78,11 @@ class ConstantFold:
|
||||
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
|
||||
if isinstance(expr, LiteralValue):
|
||||
return expr
|
||||
if expr.is_compile_constant():
|
||||
try:
|
||||
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
|
||||
elif isinstance(expr, SymbolName):
|
||||
except NotCompiletimeConstantError:
|
||||
pass
|
||||
if isinstance(expr, SymbolName):
|
||||
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
|
||||
if isinstance(value, VarDef):
|
||||
if value.vartype == VarType.MEMORY:
|
||||
@ -169,26 +175,24 @@ class ConstantFold:
|
||||
# constant-fold a dynamic expression
|
||||
if isinstance(expr, LiteralValue):
|
||||
return expr
|
||||
if expr.is_compile_constant():
|
||||
try:
|
||||
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
|
||||
elif isinstance(expr, SymbolName):
|
||||
if expr.is_compile_constant():
|
||||
except NotCompiletimeConstantError:
|
||||
pass
|
||||
if isinstance(expr, SymbolName):
|
||||
try:
|
||||
return self._process_constant_expression(expr, sourceref)
|
||||
except ExpressionEvaluationError:
|
||||
pass
|
||||
except (ExpressionEvaluationError, NotCompiletimeConstantError):
|
||||
return expr
|
||||
elif isinstance(expr, AddressOf):
|
||||
if expr.is_compile_constant():
|
||||
try:
|
||||
return self._process_constant_expression(expr, sourceref)
|
||||
except ExpressionEvaluationError:
|
||||
pass
|
||||
except (ExpressionEvaluationError, NotCompiletimeConstantError):
|
||||
return expr
|
||||
elif isinstance(expr, SubCall):
|
||||
try:
|
||||
return self._process_constant_expression(expr, sourceref)
|
||||
except ExpressionEvaluationError:
|
||||
except (ExpressionEvaluationError, NotCompiletimeConstantError):
|
||||
if isinstance(expr.target, SymbolName):
|
||||
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
|
||||
return expr
|
||||
@ -199,11 +203,9 @@ class ConstantFold:
|
||||
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
|
||||
expr.left = self._process_dynamic_expression(expr.left, left_sourceref)
|
||||
expr.left.parent = expr
|
||||
if expr.is_compile_constant():
|
||||
try:
|
||||
return self._process_constant_expression(expr, sourceref)
|
||||
except ExpressionEvaluationError:
|
||||
pass
|
||||
except (ExpressionEvaluationError, NotCompiletimeConstantError):
|
||||
return expr
|
||||
else:
|
||||
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
|
||||
@ -212,11 +214,9 @@ class ConstantFold:
|
||||
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
|
||||
expr.right = self._process_dynamic_expression(expr.right, right_sourceref)
|
||||
expr.right.parent = expr
|
||||
if expr.is_compile_constant():
|
||||
try:
|
||||
return self._process_constant_expression(expr, sourceref)
|
||||
except ExpressionEvaluationError:
|
||||
pass
|
||||
except (ExpressionEvaluationError, NotCompiletimeConstantError):
|
||||
return expr
|
||||
else:
|
||||
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
|
||||
|
@ -16,7 +16,7 @@ def generate_assignment(ctx: Context) -> None:
|
||||
assert isinstance(ctx.stmt, Assignment)
|
||||
assert not isinstance(ctx.stmt.right, Assignment), "assignment should have been flattened"
|
||||
ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
|
||||
ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left.nodes, ctx.stmt.right))
|
||||
ctx.out("\v; @todo assignment: {} = {}".format(ctx.stmt.left, ctx.stmt.right))
|
||||
# @todo assignment
|
||||
|
||||
|
||||
|
@ -15,9 +15,9 @@ def generate_goto(ctx: Context) -> None:
|
||||
ctx.out("\v\t\t\t; " + ctx.stmt.lineref)
|
||||
if stmt.condition:
|
||||
if stmt.if_stmt:
|
||||
_gen_goto_special_if_cond(ctx, stmt)
|
||||
_gen_goto_cond(ctx, stmt, "true")
|
||||
else:
|
||||
_gen_goto_cond(ctx, stmt)
|
||||
_gen_goto_cond(ctx, stmt, stmt.if_cond)
|
||||
else:
|
||||
if stmt.if_stmt:
|
||||
_gen_goto_special_if(ctx, stmt)
|
||||
@ -134,12 +134,11 @@ def _gen_goto_unconditional(ctx: Context, stmt: Goto) -> None:
|
||||
raise CodeError("invalid goto target type", stmt)
|
||||
|
||||
|
||||
def _gen_goto_special_if_cond(ctx: Context, stmt: Goto) -> None:
|
||||
pass # @todo special if WITH conditional expression
|
||||
|
||||
|
||||
def _gen_goto_cond(ctx: Context, stmt: Goto) -> None:
|
||||
pass # @todo regular if WITH conditional expression
|
||||
def _gen_goto_cond(ctx: Context, stmt: Goto, if_cond: str) -> None:
|
||||
if isinstance(stmt.condition, LiteralValue):
|
||||
pass # @todo if WITH conditional expression
|
||||
else:
|
||||
raise CodeError("no support for evaluating conditional expression yet", stmt) # @todo
|
||||
|
||||
|
||||
def generate_subcall(ctx: Context) -> None:
|
||||
|
@ -10,7 +10,7 @@ import datetime
|
||||
from typing import TextIO, Callable, no_type_check
|
||||
from ..plylex import print_bold
|
||||
from ..plyparse import (Module, ProgramFormat, Block, Directive, VarDef, Label, Subroutine, ZpOptions,
|
||||
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr, AssignmentTargets)
|
||||
InlineAssembly, Return, Register, Goto, SubCall, Assignment, AugAssignment, IncrDecr)
|
||||
from . import CodeError, to_hex, to_mflpt5, Context
|
||||
from .variables import generate_block_init, generate_block_vars
|
||||
from .assignment import generate_assignment, generate_aug_assignment
|
||||
@ -195,22 +195,25 @@ class AssemblyGenerator:
|
||||
if stmt.value_A:
|
||||
reg = Register(name="A", sourceref=stmt.sourceref)
|
||||
assignment = Assignment(sourceref=stmt.sourceref)
|
||||
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
|
||||
assignment.nodes.append(reg)
|
||||
assignment.nodes.append(stmt.value_A)
|
||||
assignment.mark_lhs()
|
||||
ctx.stmt = assignment
|
||||
generate_assignment(ctx)
|
||||
if stmt.value_X:
|
||||
reg = Register(name="X", sourceref=stmt.sourceref)
|
||||
assignment = Assignment(sourceref=stmt.sourceref)
|
||||
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
|
||||
assignment.nodes.append(reg)
|
||||
assignment.nodes.append(stmt.value_X)
|
||||
assignment.mark_lhs()
|
||||
ctx.stmt = assignment
|
||||
generate_assignment(ctx)
|
||||
if stmt.value_Y:
|
||||
reg = Register(name="Y", sourceref=stmt.sourceref)
|
||||
assignment = Assignment(sourceref=stmt.sourceref)
|
||||
assignment.nodes.append(AssignmentTargets(nodes=[reg], sourceref=stmt.sourceref))
|
||||
assignment.nodes.append(reg)
|
||||
assignment.nodes.append(stmt.value_Y)
|
||||
assignment.mark_lhs()
|
||||
ctx.stmt = assignment
|
||||
generate_assignment(ctx)
|
||||
ctx.out("\vrts")
|
||||
|
@ -7,7 +7,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
|
||||
"""
|
||||
|
||||
from typing import List, no_type_check, Union
|
||||
from .datatypes import DataType
|
||||
from .datatypes import DataType, VarType
|
||||
from .plyparse import *
|
||||
from .plylex import print_warning, print_bold
|
||||
from .constantfold import ConstantFold
|
||||
@ -117,8 +117,13 @@ class Optimizer:
|
||||
return True
|
||||
if isinstance(node1, SymbolName) and isinstance(node2, SymbolName) and node1.name == node2.name:
|
||||
return True
|
||||
if isinstance(node1, Dereference) and isinstance(node2, Dereference) and node1.operand == node2.operand:
|
||||
return True
|
||||
if isinstance(node1, Dereference) and isinstance(node2, Dereference):
|
||||
if type(node1.operand) is not type(node2.operand):
|
||||
return False
|
||||
if isinstance(node1.operand, (SymbolName, LiteralValue, Register)):
|
||||
return node1.operand == node2.operand
|
||||
if not isinstance(node1, AstNode) or not isinstance(node2, AstNode):
|
||||
raise TypeError("same_target called with invalid type(s)", node1, node2)
|
||||
return False
|
||||
|
||||
@no_type_check
|
||||
@ -132,7 +137,7 @@ class Optimizer:
|
||||
continue
|
||||
expr = assignment.right
|
||||
if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators
|
||||
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
|
||||
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
|
||||
num_val = expr.right.const_value()
|
||||
operator = expr.operator + '='
|
||||
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
|
||||
@ -141,13 +146,13 @@ class Optimizer:
|
||||
continue
|
||||
if expr.operator not in ('+', '*', '|', '^'): # associative operators
|
||||
continue
|
||||
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
|
||||
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.left):
|
||||
num_val = expr.right.const_value()
|
||||
operator = expr.operator + '='
|
||||
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
|
||||
assignment.my_scope().replace_node(assignment, aug_assign)
|
||||
self.optimizations_performed = True
|
||||
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right):
|
||||
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left, expr.right):
|
||||
num_val = expr.left.const_value()
|
||||
operator = expr.operator + '='
|
||||
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
|
||||
@ -161,8 +166,12 @@ class Optimizer:
|
||||
prev_node = None # type: AstNode
|
||||
for node in list(scope.nodes):
|
||||
if isinstance(node, Assignment) and isinstance(prev_node, Assignment):
|
||||
if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left):
|
||||
if not node.left.has_memvalue():
|
||||
if isinstance(node.right, (LiteralValue, Register)) and self._same_target(node.left, prev_node.left):
|
||||
if isinstance(node.left, SymbolName):
|
||||
# only optimize if the symbol is not a memory mapped address (volatile memory!)
|
||||
symdef = node.left.my_scope().lookup(node.left.name)
|
||||
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
|
||||
continue
|
||||
scope.remove_node(prev_node)
|
||||
self.optimizations_performed = True
|
||||
self.num_warnings += 1
|
||||
@ -178,17 +187,17 @@ class Optimizer:
|
||||
# @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?)
|
||||
for assignment in self.module.all_nodes():
|
||||
if isinstance(assignment, Assignment):
|
||||
if all(lv == assignment.right for lv in assignment.left.nodes):
|
||||
if self._same_target(assignment.left, assignment.right):
|
||||
assignment.my_scope().remove_node(assignment)
|
||||
self.optimizations_performed = True
|
||||
self.num_warnings += 1
|
||||
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
|
||||
print_warning("{}: removed statement that has no effect (left=right)".format(assignment.sourceref))
|
||||
elif isinstance(assignment, AugAssignment):
|
||||
if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)):
|
||||
if assignment.right.value == 0:
|
||||
if assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="):
|
||||
self.num_warnings += 1
|
||||
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
|
||||
print_warning("{}: removed statement that has no effect (aug.assign zero)".format(assignment.sourceref))
|
||||
assignment.my_scope().remove_node(assignment)
|
||||
self.optimizations_performed = True
|
||||
elif assignment.operator == "*=":
|
||||
@ -206,9 +215,10 @@ class Optimizer:
|
||||
elif assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="):
|
||||
print("{}: shifting result is always zero".format(assignment.sourceref))
|
||||
new_stmt = Assignment(sourceref=assignment.sourceref)
|
||||
new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref))
|
||||
new_stmt.nodes.append(assignment.left)
|
||||
new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref))
|
||||
assignment.my_scope().replace_node(assignment, new_stmt)
|
||||
assignment.mark_lhs()
|
||||
self.optimizations_performed = True
|
||||
elif assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256:
|
||||
howmuch = assignment.right
|
||||
@ -223,7 +233,7 @@ class Optimizer:
|
||||
self.optimizations_performed = True
|
||||
elif assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="):
|
||||
self.num_warnings += 1
|
||||
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
|
||||
print_warning("{}: removed statement that has no effect (aug.assign identity)".format(assignment.sourceref))
|
||||
assignment.my_scope().remove_node(assignment)
|
||||
self.optimizations_performed = True
|
||||
|
||||
@ -231,12 +241,13 @@ class Optimizer:
|
||||
def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment:
|
||||
new_assignment = Assignment(sourceref=old_aug_assignment.sourceref)
|
||||
new_assignment.parent = old_aug_assignment.parent
|
||||
left = AssignmentTargets(nodes=[old_aug_assignment.left], sourceref=old_aug_assignment.sourceref)
|
||||
left = old_aug_assignment.left
|
||||
left.parent = new_assignment
|
||||
new_assignment.nodes.append(left)
|
||||
value = LiteralValue(value=constantvalue, sourceref=old_aug_assignment.sourceref)
|
||||
value.parent = new_assignment
|
||||
new_assignment.nodes.append(value)
|
||||
new_assignment.mark_lhs()
|
||||
return new_assignment
|
||||
|
||||
@no_type_check
|
||||
@ -250,6 +261,7 @@ class Optimizer:
|
||||
a.nodes.append(lv)
|
||||
lv.parent = a
|
||||
a.parent = old_assign.parent
|
||||
a.mark_lhs()
|
||||
return a
|
||||
|
||||
@no_type_check
|
||||
|
116
il65/plyparse.py
116
il65/plyparse.py
@ -22,8 +22,8 @@ __all__ = ["ProgramFormat", "ZpOptions", "math_functions", "builtin_functions",
|
||||
"UndefinedSymbolError", "AstNode", "Directive", "Scope", "Block", "Module", "Label", "Expression",
|
||||
"Register", "Subroutine", "LiteralValue", "AddressOf", "SymbolName", "Dereference", "IncrDecr",
|
||||
"ExpressionWithOperator", "Goto", "SubCall", "VarDef", "Return", "Assignment", "AugAssignment",
|
||||
"InlineAssembly", "AssignmentTargets", "BuiltinFunction", "TokenFilter", "parser", "connect_parents",
|
||||
"parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition"]
|
||||
"InlineAssembly", "BuiltinFunction", "TokenFilter", "parser", "connect_parents",
|
||||
"parse_file", "coerce_constant_value", "datatype_of", "check_symbol_definition", "NotCompiletimeConstantError"]
|
||||
|
||||
|
||||
class ProgramFormat(enum.Enum):
|
||||
@ -55,6 +55,10 @@ class ParseError(Exception):
|
||||
return "{} {:s}".format(self.sourceref, self.args[0])
|
||||
|
||||
|
||||
class NotCompiletimeConstantError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class ExpressionEvaluationError(ParseError):
|
||||
pass
|
||||
|
||||
@ -348,7 +352,9 @@ class Label(AstNode):
|
||||
class Expression(AstNode):
|
||||
# just a common base class for the nodes that are an expression themselves:
|
||||
# ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference
|
||||
def is_compile_constant(self) -> bool:
|
||||
is_lhs = attr.ib(type=bool, init=False, default=False) # left hand side of incrdecr/assignment/augassign?
|
||||
|
||||
def is_compiletime_const(self) -> bool:
|
||||
raise NotImplementedError("implement in subclass")
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
@ -382,11 +388,11 @@ class Register(Expression):
|
||||
return NotImplemented
|
||||
return self.name < other.name
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
return False
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
raise TypeError("register doesn't have a constant numeric value", self)
|
||||
raise NotCompiletimeConstantError("register doesn't have a constant numeric value", self)
|
||||
|
||||
|
||||
@attr.s(cmp=False)
|
||||
@ -464,7 +470,7 @@ class LiteralValue(Expression):
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
return self.value
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@ -473,7 +479,7 @@ class AddressOf(Expression):
|
||||
# no subnodes.
|
||||
name = attr.ib(type=str, validator=attr.validators._InstanceOfValidator(type=str))
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
# address-of can be a compile time constant if the operand is a memory mapped variable or ZP variable
|
||||
symdef = self.my_scope().lookup(self.name)
|
||||
return isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY \
|
||||
@ -486,16 +492,16 @@ class AddressOf(Expression):
|
||||
return symdef.zp_address
|
||||
if symdef.vartype == VarType.MEMORY:
|
||||
return symdef.value.const_value()
|
||||
raise TypeError("can only take constant address of a memory mapped variable", self)
|
||||
raise TypeError("should be a vardef to be able to take its address", self)
|
||||
raise NotCompiletimeConstantError("can only take constant address of a memory mapped variable", self)
|
||||
raise NotCompiletimeConstantError("should be a vardef to be able to take its address", self)
|
||||
|
||||
|
||||
@attr.s(cmp=False, slots=True)
|
||||
@attr.s(cmp=True, slots=True)
|
||||
class SymbolName(Expression):
|
||||
# no subnodes.
|
||||
name = attr.ib(type=str)
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
symdef = self.my_scope().lookup(self.name)
|
||||
return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST
|
||||
|
||||
@ -503,7 +509,7 @@ class SymbolName(Expression):
|
||||
symdef = self.my_scope().lookup(self.name)
|
||||
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
|
||||
return symdef.const_value()
|
||||
raise TypeError("should be a const vardef to be able to take its constant numeric value", self)
|
||||
raise NotCompiletimeConstantError("should be a const vardef to be able to take its constant numeric value", self)
|
||||
|
||||
|
||||
@attr.s(cmp=False)
|
||||
@ -531,11 +537,11 @@ class Dereference(Expression):
|
||||
if self.nodes and not isinstance(self.nodes[0], (SymbolName, LiteralValue, Register)):
|
||||
raise TypeError("operand of dereference invalid type", self.nodes[0], self.sourceref)
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
return False
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
raise TypeError("dereference is not a constant numeric value")
|
||||
raise NotCompiletimeConstantError("dereference is not a constant numeric value")
|
||||
|
||||
|
||||
@attr.s(cmp=False)
|
||||
@ -556,6 +562,8 @@ class IncrDecr(AstNode):
|
||||
raise ParseError("cannot incr/decr that register", self.sourceref)
|
||||
assert isinstance(target, (Register, SymbolName, Dereference))
|
||||
self.nodes.clear()
|
||||
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
|
||||
target.is_lhs = True
|
||||
self.nodes.append(target)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
@ -569,8 +577,6 @@ class IncrDecr(AstNode):
|
||||
class ExpressionWithOperator(Expression):
|
||||
# 2 nodes: left (Expression), right (not present if unary, Expression if not unary)
|
||||
operator = attr.ib(type=str)
|
||||
# when evaluating the expression, does it have to be a compile-time constant value?
|
||||
must_be_constant = attr.ib(type=bool, init=False, default=False)
|
||||
|
||||
@property
|
||||
def unary(self) -> bool:
|
||||
@ -612,7 +618,7 @@ class ExpressionWithOperator(Expression):
|
||||
elif self.operator == "not":
|
||||
return not cv[0]
|
||||
elif self.operator == "&":
|
||||
raise TypeError("the address-of operator should have been parsed into an AddressOf node")
|
||||
raise NotCompiletimeConstantError("the address-of operator should have been parsed into an AddressOf node")
|
||||
else:
|
||||
raise ValueError("invalid unary operator: "+self.operator, self.sourceref)
|
||||
else:
|
||||
@ -664,11 +670,11 @@ class ExpressionWithOperator(Expression):
|
||||
raise ValueError("invalid operator: "+self.operator, self.sourceref)
|
||||
|
||||
@no_type_check
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
if len(self.nodes) == 1:
|
||||
return self.nodes[0].is_compile_constant()
|
||||
return self.nodes[0].is_compiletime_const()
|
||||
elif len(self.nodes) == 2:
|
||||
return self.nodes[0].is_compile_constant() and self.nodes[1].is_compile_constant()
|
||||
return self.nodes[0].is_compiletime_const() and self.nodes[1].is_compiletime_const()
|
||||
raise ValueError("should have 1 or 2 nodes")
|
||||
|
||||
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
|
||||
@ -742,7 +748,7 @@ class SubCall(Expression):
|
||||
def arguments(self) -> CallArguments:
|
||||
return self.nodes[2] # type: ignore
|
||||
|
||||
def is_compile_constant(self) -> bool:
|
||||
def is_compiletime_const(self) -> bool:
|
||||
if isinstance(self.nodes[0], SymbolName):
|
||||
symdef = self.nodes[0].my_scope().lookup(self.nodes[0].name)
|
||||
if isinstance(symdef, BuiltinFunction):
|
||||
@ -756,7 +762,7 @@ class SubCall(Expression):
|
||||
if isinstance(symdef, BuiltinFunction):
|
||||
arguments = [a.nodes[0].const_value() for a in self.nodes[2].nodes]
|
||||
return symdef.func(*arguments)
|
||||
raise TypeError("subroutine call is not a constant value", self)
|
||||
raise NotCompiletimeConstantError("subroutine call is not a constant value", self)
|
||||
|
||||
|
||||
@attr.s(cmp=False, slots=True, repr=False)
|
||||
@ -779,13 +785,10 @@ class VarDef(AstNode):
|
||||
self.nodes[0] = value
|
||||
else:
|
||||
self.nodes.append(value)
|
||||
if isinstance(value, ExpressionWithOperator):
|
||||
# an expression in a vardef should evaluate to a compile-time constant:
|
||||
value.must_be_constant = True
|
||||
|
||||
def const_value(self) -> Union[int, float, bool, str]:
|
||||
if self.vartype != VarType.CONST:
|
||||
raise TypeError("not a constant value", self)
|
||||
raise NotCompiletimeConstantError("not a constant value", self)
|
||||
if self.nodes and isinstance(self.nodes[0], Expression):
|
||||
return self.nodes[0].const_value()
|
||||
raise ValueError("no value", self)
|
||||
@ -841,59 +844,14 @@ class Return(AstNode):
|
||||
return self.nodes[2] if len(self.nodes) >= 3 else None # type: ignore
|
||||
|
||||
|
||||
@attr.s(cmp=False, slots=True, repr=False)
|
||||
class AssignmentTargets(AstNode):
|
||||
# a list of one or more assignment targets (Register, SymbolName, or Dereference).
|
||||
nodes = attr.ib(type=list, init=True) # requires nodes in __init__
|
||||
|
||||
def has_memvalue(self) -> bool:
|
||||
for t in self.nodes:
|
||||
if isinstance(t, Dereference):
|
||||
return True
|
||||
if isinstance(t, SymbolName):
|
||||
symdef = self.my_scope().lookup(t.name)
|
||||
if isinstance(symdef, VarDef) and symdef.vartype == VarType.MEMORY:
|
||||
return True
|
||||
return False
|
||||
|
||||
def same_targets(self, other: 'AssignmentTargets') -> bool:
|
||||
if len(self.nodes) != len(other.nodes):
|
||||
return False
|
||||
# @todo be able to compare targets in different order as well (sort them)
|
||||
for t1, t2 in zip(self.nodes, other.nodes):
|
||||
if type(t1) is not type(t2):
|
||||
return False
|
||||
if isinstance(t1, Register):
|
||||
if t1 != t2: # __eq__ is defined
|
||||
return False
|
||||
elif isinstance(t1, SymbolName):
|
||||
if t1.name != t2.name:
|
||||
return False
|
||||
elif isinstance(t1, Dereference):
|
||||
if t1.size != t2.size or t1.datatype != t2.datatype:
|
||||
return False
|
||||
op1, op2 = t1.operand, t2.operand
|
||||
if type(op1) is not type(op2):
|
||||
return False
|
||||
if isinstance(op1, SymbolName):
|
||||
if op1.name != op2.name:
|
||||
return False
|
||||
else:
|
||||
if op1 != op2:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@attr.s(cmp=False, slots=True, repr=False)
|
||||
class Assignment(AstNode):
|
||||
# can be single- or multi-assignment
|
||||
# has two subnodes: left (=AssignmentTargets) and right (=Expression,
|
||||
# has two subnodes: left (=Register/SymbolName/Dereference ) and right (=Expression,
|
||||
# or another Assignment but those will be converted into multi assign)
|
||||
|
||||
@property
|
||||
def left(self) -> AssignmentTargets:
|
||||
def left(self) -> Union[Register, SymbolName, Dereference]:
|
||||
return self.nodes[0] # type: ignore
|
||||
|
||||
@property
|
||||
@ -905,6 +863,10 @@ class Assignment(AstNode):
|
||||
assert isinstance(rvalue, Expression)
|
||||
self.nodes[1] = rvalue
|
||||
|
||||
def mark_lhs(self):
|
||||
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
|
||||
self.nodes[0].is_lhs = True
|
||||
|
||||
|
||||
@attr.s(cmp=False, slots=True, repr=False)
|
||||
class AugAssignment(AstNode):
|
||||
@ -924,6 +886,10 @@ class AugAssignment(AstNode):
|
||||
assert isinstance(rvalue, Expression)
|
||||
self.nodes[1] = rvalue
|
||||
|
||||
def mark_lhs(self):
|
||||
# the expression on the left hand side should be marked LHS to avoid improper constant folding/replacement.
|
||||
self.nodes[0].is_lhs = True
|
||||
|
||||
|
||||
def datatype_of(targetnode: AstNode, scope: Scope) -> DataType:
|
||||
# tries to determine the DataType of an assignment target node
|
||||
@ -1512,8 +1478,9 @@ def p_assignment(p):
|
||||
| assignment_target IS assignment
|
||||
"""
|
||||
p[0] = Assignment(sourceref=_token_sref(p, 2))
|
||||
p[0].nodes.append(AssignmentTargets(nodes=[p[1]], sourceref=p[0].sourceref))
|
||||
p[0].nodes.append(p[1])
|
||||
p[0].nodes.append(p[3])
|
||||
p[0].mark_lhs()
|
||||
|
||||
|
||||
def p_aug_assignment(p):
|
||||
@ -1523,6 +1490,7 @@ def p_aug_assignment(p):
|
||||
p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2))
|
||||
p[0].nodes.append(p[1])
|
||||
p[0].nodes.append(p[3])
|
||||
p[0].mark_lhs()
|
||||
|
||||
|
||||
precedence = (
|
||||
|
@ -164,16 +164,14 @@ def test_block_nodes():
|
||||
assert len(sub2.scope.nodes) > 0
|
||||
|
||||
|
||||
test_source_2 = """
|
||||
def test_parser_2():
|
||||
src = """
|
||||
~ {
|
||||
999(1,2)
|
||||
[zz]()
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_parser_2():
|
||||
result = parse_source(test_source_2)
|
||||
result = parse_source(src)
|
||||
block = result.scope.nodes[0]
|
||||
call = block.scope.nodes[0]
|
||||
assert isinstance(call, SubCall)
|
||||
@ -187,7 +185,8 @@ def test_parser_2():
|
||||
assert call.target.operand.name == "zz"
|
||||
|
||||
|
||||
test_source_3 = """
|
||||
def test_typespec():
|
||||
src = """
|
||||
~ {
|
||||
[$c000.word] = 5
|
||||
[$c000 .byte] = 5
|
||||
@ -195,10 +194,7 @@ test_source_3 = """
|
||||
[AX .float] = 5
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_typespec():
|
||||
result = parse_source(test_source_3)
|
||||
result = parse_source(src)
|
||||
block = result.scope.nodes[0]
|
||||
assignment1, assignment2, assignment3, assignment4 = block.scope.nodes
|
||||
assert assignment1.right.value == 5
|
||||
@ -209,10 +205,10 @@ def test_typespec():
|
||||
assert len(assignment2.left.nodes) == 1
|
||||
assert len(assignment3.left.nodes) == 1
|
||||
assert len(assignment4.left.nodes) == 1
|
||||
t1 = assignment1.left.nodes[0]
|
||||
t2 = assignment2.left.nodes[0]
|
||||
t3 = assignment3.left.nodes[0]
|
||||
t4 = assignment4.left.nodes[0]
|
||||
t1 = assignment1.left
|
||||
t2 = assignment2.left
|
||||
t3 = assignment3.left
|
||||
t4 = assignment4.left
|
||||
assert isinstance(t1, Dereference)
|
||||
assert isinstance(t2, Dereference)
|
||||
assert isinstance(t3, Dereference)
|
||||
@ -235,7 +231,8 @@ def test_typespec():
|
||||
assert t4.size is None
|
||||
|
||||
|
||||
test_source_4 = """
|
||||
def test_char_string():
|
||||
src = """
|
||||
~ {
|
||||
var x1 = '@'
|
||||
var x2 = 'π'
|
||||
@ -245,10 +242,7 @@ test_source_4 = """
|
||||
A = 'abc'
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_char_string():
|
||||
result = parse_source(test_source_4)
|
||||
result = parse_source(src)
|
||||
block = result.scope.nodes[0]
|
||||
var1, var2, var3, assgn1, assgn2, assgn3, = block.scope.nodes
|
||||
assert var1.value.value == '@'
|
||||
@ -260,7 +254,8 @@ def test_char_string():
|
||||
# note: the actual one-charactor-to-bytevalue conversion is done at the very latest, when issuing an assignment statement
|
||||
|
||||
|
||||
test_source_5 = """
|
||||
def test_boolean_int():
|
||||
src = """
|
||||
~ {
|
||||
var x1 = true
|
||||
var x2 = false
|
||||
@ -268,10 +263,7 @@ test_source_5 = """
|
||||
A = false
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def test_boolean_int():
|
||||
result = parse_source(test_source_5)
|
||||
result = parse_source(src)
|
||||
block = result.scope.nodes[0]
|
||||
var1, var2, assgn1, assgn2, = block.scope.nodes
|
||||
assert type(var1.value.value) is int and var1.value.value == 1
|
||||
@ -381,7 +373,7 @@ def test_const_numeric_expressions():
|
||||
result.scope.define_builtin_functions()
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert all(x.is_compile_constant() for x in e)
|
||||
assert all(x.is_compiletime_const() for x in e)
|
||||
assert e[0].const_value() == 15 # 1+2+3+4+5
|
||||
assert e[1].const_value() == 13 # 1+2*5+2
|
||||
assert e[2].const_value() == 21 # (1+2)*(5+2)
|
||||
@ -415,7 +407,7 @@ def test_const_logic_expressions():
|
||||
result = parse_source(src)
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert all(x.is_compile_constant() for x in e)
|
||||
assert all(x.is_compiletime_const() for x in e)
|
||||
assert e[0].const_value() == True
|
||||
assert e[1].const_value() == False
|
||||
assert e[2].const_value() == True
|
||||
@ -441,12 +433,12 @@ def test_const_other_expressions():
|
||||
result.scope.define_builtin_functions()
|
||||
assignments = list(result.all_nodes(Assignment))
|
||||
e = [a.nodes[1] for a in assignments]
|
||||
assert e[0].is_compile_constant()
|
||||
assert e[0].is_compiletime_const()
|
||||
assert e[0].const_value() == 0xc123
|
||||
assert not e[1].is_compile_constant()
|
||||
assert not e[1].is_compiletime_const()
|
||||
with pytest.raises(TypeError):
|
||||
e[1].const_value()
|
||||
assert not e[2].is_compile_constant()
|
||||
assert not e[2].is_compiletime_const()
|
||||
with pytest.raises(TypeError):
|
||||
e[2].const_value()
|
||||
|
||||
@ -495,3 +487,67 @@ def test_vdef_const_folds():
|
||||
assert vd[2].datatype == DataType.BYTE
|
||||
assert isinstance(vd[2].value, LiteralValue)
|
||||
assert vd[2].value.value == 369
|
||||
|
||||
|
||||
def test_vdef_const_expressions():
|
||||
src = """
|
||||
~ {
|
||||
var bvar = 99
|
||||
var .float fvar = sin(1.2-0.3)
|
||||
var .float flt2 = -9.87e-6
|
||||
|
||||
bvar ++
|
||||
fvar ++
|
||||
flt2 ++
|
||||
bvar += 2+2
|
||||
fvar += 2+3
|
||||
flt2 += 2+4
|
||||
bvar = 2+5
|
||||
fvar = 2+6
|
||||
flt2 = 2+7
|
||||
}
|
||||
"""
|
||||
result = parse_source(src)
|
||||
if isinstance(result, Module):
|
||||
result.scope.define_builtin_functions()
|
||||
cf = ConstantFold(result)
|
||||
cf.fold_constants()
|
||||
vd = list(result.all_nodes(VarDef))
|
||||
assert len(vd)==3
|
||||
assert vd[0].name == "bvar"
|
||||
assert isinstance(vd[0].value, LiteralValue)
|
||||
assert vd[0].value.value == 99
|
||||
assert vd[1].name == "fvar"
|
||||
assert isinstance(vd[1].value, LiteralValue)
|
||||
assert type(vd[1].value.value) is float
|
||||
assert math.isclose(vd[1].value.value, math.sin(0.9))
|
||||
assert vd[2].name == "flt2"
|
||||
assert isinstance(vd[2].value, LiteralValue)
|
||||
assert math.isclose(-9.87e-6, vd[2].value.value)
|
||||
# test incrdecr assignment target
|
||||
nodes = list(result.all_nodes(IncrDecr))
|
||||
assert len(nodes) == 3
|
||||
assert isinstance(nodes[0].target, SymbolName)
|
||||
assert nodes[0].target.name == "bvar"
|
||||
assert isinstance(nodes[1].target, SymbolName)
|
||||
assert nodes[1].target.name == "fvar"
|
||||
assert isinstance(nodes[2].target, SymbolName)
|
||||
assert nodes[2].target.name == "flt2"
|
||||
# test augassign assignment target
|
||||
nodes = list(result.all_nodes(AugAssignment))
|
||||
assert len(nodes) == 3
|
||||
assert isinstance(nodes[0].left, SymbolName)
|
||||
assert nodes[0].left.name == "bvar"
|
||||
assert isinstance(nodes[1].left, SymbolName)
|
||||
assert nodes[1].left.name == "fvar"
|
||||
assert isinstance(nodes[2].left, SymbolName)
|
||||
assert nodes[2].left.name == "flt2"
|
||||
# test assign assignment target
|
||||
nodes = list(result.all_nodes(Assignment))
|
||||
assert len(nodes) == 3
|
||||
assert isinstance(nodes[0].left, SymbolName)
|
||||
assert nodes[0].left.name == "bvar"
|
||||
assert isinstance(nodes[1].left, SymbolName)
|
||||
assert nodes[1].left.name == "fvar"
|
||||
assert isinstance(nodes[2].left, SymbolName)
|
||||
assert nodes[2].left.name == "flt2"
|
||||
|
@ -3,13 +3,6 @@ from il65.datatypes import DataType
|
||||
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError
|
||||
from il65.plylex import SourceRef
|
||||
|
||||
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
|
||||
# name = attr.ib(type=str)
|
||||
# vartype = attr.ib()
|
||||
# datatype = attr.ib()
|
||||
# size = attr.ib(type=list, default=None)
|
||||
# zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
|
||||
|
||||
|
||||
def test_creation():
|
||||
sref = SourceRef("test", 1, 1)
|
||||
@ -62,10 +55,8 @@ def test_set_value():
|
||||
assert v.value.value == "hello"
|
||||
e = ExpressionWithOperator(operator="-", sourceref=sref)
|
||||
e.left = LiteralValue(value=42, sourceref=sref)
|
||||
assert not e.must_be_constant
|
||||
v.value = e
|
||||
assert v.value is e
|
||||
assert e.must_be_constant
|
||||
|
||||
|
||||
def test_const_value():
|
||||
|
Loading…
x
Reference in New Issue
Block a user