more explicit use of Expression, fixed some optimizations

This commit is contained in:
Irmen de Jong 2018-01-29 22:34:28 +01:00
parent f82ceab969
commit 4d70e3d42f
7 changed files with 206 additions and 167 deletions

View File

@ -192,7 +192,7 @@ class PlyParser:
lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes) lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes)
if len(lvalue_types) == 1: if len(lvalue_types) == 1:
_, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) _, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref)
if isinstance(newright, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall)): if isinstance(newright, Expression):
node.right = newright # type: ignore node.right = newright # type: ignore
else: else:
raise TypeError("invalid coerced constant type", newright) raise TypeError("invalid coerced constant type", newright)
@ -322,7 +322,7 @@ class PlyParser:
return return
elif isinstance(expr, SubCall): elif isinstance(expr, SubCall):
self._get_subroutine_usages_from_subcall(usages, expr, parent_scope) self._get_subroutine_usages_from_subcall(usages, expr, parent_scope)
elif isinstance(expr, Expression): elif isinstance(expr, ExpressionWithOperator):
self._get_subroutine_usages_from_expression(usages, expr.left, parent_scope) self._get_subroutine_usages_from_expression(usages, expr.left, parent_scope)
self._get_subroutine_usages_from_expression(usages, expr.right, parent_scope) self._get_subroutine_usages_from_expression(usages, expr.right, parent_scope)
elif isinstance(expr, LiteralValue): elif isinstance(expr, LiteralValue):

View File

@ -36,7 +36,7 @@ def generate_aug_assignment(out: Callable, stmt: AugAssignment, scope: Scope) ->
elif isinstance(rvalue, SymbolName): elif isinstance(rvalue, SymbolName):
symdef = scope.lookup(rvalue.name) symdef = scope.lookup(rvalue.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST and symdef.datatype.isinteger(): if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST and symdef.datatype.isinteger():
if 0 <= symdef.value.const_num_val() <= 255: if 0 <= symdef.value.const_value() <= 255: # type: ignore
_generate_aug_reg_constant_int(out, lvalue, stmt.operator, 0, symdef.name, scope) _generate_aug_reg_constant_int(out, lvalue, stmt.operator, 0, symdef.name, scope)
else: else:
raise CodeError("assignment value must be 0..255", rvalue) raise CodeError("assignment value must be 0..255", rvalue)
@ -136,7 +136,7 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
else: else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo ^=.word raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo ^=.word
elif operator == ">>=": elif operator == ">>=":
if rvalue.value > 0: if rvalue > 0:
def shifts_A(times: int) -> None: def shifts_A(times: int) -> None:
if times >= 8: if times >= 8:
out("\vlda #0") out("\vlda #0")
@ -144,21 +144,21 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
for _ in range(min(8, times)): for _ in range(min(8, times)):
out("\vlsr a") out("\vlsr a")
if lvalue.name == "A": if lvalue.name == "A":
shifts_A(rvalue.value) shifts_A(rvalue)
elif lvalue.name == "X": elif lvalue.name == "X":
with preserving_registers({'A'}, scope, out): with preserving_registers({'A'}, scope, out):
out("\vtxa") out("\vtxa")
shifts_A(rvalue.value) shifts_A(rvalue)
out("\vtax") out("\vtax")
elif lvalue.name == "Y": elif lvalue.name == "Y":
with preserving_registers({'A'}, scope, out): with preserving_registers({'A'}, scope, out):
out("\vtya") out("\vtya")
shifts_A(rvalue.value) shifts_A(rvalue)
out("\vtay") out("\vtay")
else: else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo >>=.word raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo >>=.word
elif operator == "<<=": elif operator == "<<=":
if rvalue.value > 0: if rvalue > 0:
def shifts_A(times: int) -> None: def shifts_A(times: int) -> None:
if times >= 8: if times >= 8:
out("\vlda #0") out("\vlda #0")
@ -166,16 +166,16 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
for _ in range(min(8, times)): for _ in range(min(8, times)):
out("\vasl a") out("\vasl a")
if lvalue.name == "A": if lvalue.name == "A":
shifts_A(rvalue.value) shifts_A(rvalue)
elif lvalue.name == "X": elif lvalue.name == "X":
with preserving_registers({'A'}, scope, out): with preserving_registers({'A'}, scope, out):
out("\vtxa") out("\vtxa")
shifts_A(rvalue.value) shifts_A(rvalue)
out("\vtax") out("\vtax")
elif lvalue.name == "Y": elif lvalue.name == "Y":
with preserving_registers({'A'}, scope, out): with preserving_registers({'A'}, scope, out):
out("\vtya") out("\vtya")
shifts_A(rvalue.value) shifts_A(rvalue)
out("\vtay") out("\vtay")
else: else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo <<=.word raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo <<=.word

View File

@ -58,11 +58,14 @@ class Optimizer:
def constant_folding(self) -> None: def constant_folding(self) -> None:
for expression in self.module.all_nodes(Expression): for expression in self.module.all_nodes(Expression):
if isinstance(expression, LiteralValue):
continue
try: try:
evaluated = process_expression(expression, expression.sourceref) # type: ignore evaluated = process_expression(expression) # type: ignore
if evaluated is not expression: if evaluated is not expression:
# replace the node with the newly evaluated result # replace the node with the newly evaluated result
expression.parent.replace_node(expression, evaluated) expression.parent.replace_node(expression, evaluated)
self.optimizations_performed = True
except ParseError: except ParseError:
raise raise
except Exception as x: except Exception as x:
@ -148,12 +151,12 @@ class Optimizer:
for assignment in self.module.all_nodes(Assignment): for assignment in self.module.all_nodes(Assignment):
if len(assignment.left.nodes) > 1: if len(assignment.left.nodes) > 1:
continue continue
if not isinstance(assignment.right, Expression) or assignment.right.unary: if not isinstance(assignment.right, ExpressionWithOperator) or assignment.right.unary:
continue continue
expr = assignment.right expr = assignment.right
if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
num_val = expr.right.const_num_val() num_val = expr.right.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign) assignment.my_scope().replace_node(assignment, aug_assign)
@ -162,13 +165,13 @@ class Optimizer:
if expr.operator not in ('+', '*', '|', '^'): # associative operators if expr.operator not in ('+', '*', '|', '^'): # associative operators
continue continue
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
num_val = expr.right.const_num_val() num_val = expr.right.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign) assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True self.optimizations_performed = True
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right): elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right):
num_val = expr.left.const_num_val() num_val = expr.left.const_value()
operator = expr.operator + '=' operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign) assignment.my_scope().replace_node(assignment, aug_assign)
@ -189,6 +192,7 @@ class Optimizer:
print_warning("{}: removed superfluous assignment".format(prev_node.sourceref)) print_warning("{}: removed superfluous assignment".format(prev_node.sourceref))
prev_node = node prev_node = node
@no_type_check
def optimize_assignments(self) -> None: def optimize_assignments(self) -> None:
# remove assignment statements that do nothing (A=A) # remove assignment statements that do nothing (A=A)
# remove augmented assignments that have no effect (x+=0, x-=0, x/=1, x//=1, x*=1) # remove augmented assignments that have no effect (x+=0, x-=0, x/=1, x//=1, x*=1)
@ -385,27 +389,27 @@ class Optimizer:
node.my_scope().nodes.remove(node) node.my_scope().nodes.remove(node)
def process_expression(expr: Expression, sourceref: SourceRef) -> Any: def process_expression(expr: Expression) -> Any:
# process/simplify all expressions (constant folding etc) # process/simplify all expressions (constant folding etc)
if expr.must_be_constant: if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
return process_constant_expression(expr, sourceref) return process_constant_expression(expr, expr.sourceref)
else: else:
return process_dynamic_expression(expr, sourceref) return process_dynamic_expression(expr, expr.sourceref)
def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue: def process_constant_expression(expr: Expression, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, (int, float, str, bool)): if isinstance(expr, LiteralValue):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName): elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef): if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY: if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref) raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value value = value.value
if isinstance(value, Expression): if isinstance(value, ExpressionWithOperator):
raise ExpressionEvaluationError("circular reference?", expr.sourceref) raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue): elif isinstance(value, LiteralValue):
return value return value
@ -452,8 +456,7 @@ def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else: else:
raise NotImplementedError("weird call target", expr.target) raise NotImplementedError("weird call target", expr.target)
elif not isinstance(expr, Expression): elif isinstance(expr, ExpressionWithOperator):
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary: if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref) expr.left = process_constant_expression(expr.left, left_sourceref)
@ -483,14 +486,16 @@ def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue
else: else:
raise ExpressionEvaluationError("constant literal value required on left, not {}" raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref) .format(expr.left.__class__.__name__), left_sourceref)
else:
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any: def process_dynamic_expression(expr: Expression, sourceref: SourceRef) -> Any:
# constant-fold a dynamic expression # constant-fold a dynamic expression
if isinstance(expr, (int, float, str, bool)): if isinstance(expr, LiteralValue):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName): elif isinstance(expr, SymbolName):
try: try:
return process_constant_expression(expr, sourceref) return process_constant_expression(expr, sourceref)
@ -514,8 +519,7 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
if isinstance(expr.operand, SymbolName): if isinstance(expr.operand, SymbolName):
check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref) check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref)
return expr return expr
elif not isinstance(expr, Expression): elif isinstance(expr, ExpressionWithOperator):
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary: if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref) expr.left = process_dynamic_expression(expr.left, left_sourceref)
@ -532,6 +536,8 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
return process_constant_expression(expr, sourceref) return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError: except ExpressionEvaluationError:
return expr return expr
else:
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
def optimize(mod: Module) -> None: def optimize(mod: Module) -> None:

View File

@ -77,10 +77,11 @@ class AstNode:
def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]: def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]:
nodetypes = nodetypes or (AstNode, ) nodetypes = nodetypes or (AstNode, )
for node in list(self.nodes): child_nodes = list(self.nodes)
if isinstance(node, nodetypes): # type: ignore for node in child_nodes:
if isinstance(node, nodetypes):
yield node yield node
for node in self.nodes: for node in child_nodes:
if isinstance(node, AstNode): if isinstance(node, AstNode):
yield from node.all_nodes(*nodetypes) yield from node.all_nodes(*nodetypes)
@ -293,8 +294,19 @@ class Label(AstNode):
# no subnodes. # no subnodes.
@attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode):
# just a common base class for the nodes that are an expression themselves:
# ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference
def is_compile_constant(self) -> bool:
raise NotImplementedError("implement in subclass")
def const_value(self) -> Union[int, float, bool, str]:
raise NotImplementedError("implement in subclass")
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True)
class Register(AstNode): class Register(Expression):
name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS)) name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS))
datatype = attr.ib(type=DataType, init=False) datatype = attr.ib(type=DataType, init=False)
# no subnodes. # no subnodes.
@ -320,6 +332,12 @@ class Register(AstNode):
return NotImplemented return NotImplemented
return self.name < other.name return self.name < other.name
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("register doesn't have a constant numeric value", self)
@attr.s(cmp=False) @attr.s(cmp=False)
class PreserveRegs(AstNode): class PreserveRegs(AstNode):
@ -386,51 +404,57 @@ class Subroutine(AstNode):
@attr.s(cmp=True, slots=True, repr=False) @attr.s(cmp=True, slots=True, repr=False)
class LiteralValue(AstNode): class LiteralValue(Expression):
# no subnodes. # no subnodes.
value = attr.ib() value = attr.ib()
def __repr__(self) -> str: def __repr__(self) -> str:
return "<LiteralValue value={!r} at {}>".format(self.value, self.sourceref) return "<LiteralValue value={!r} at {}>".format(self.value, self.sourceref)
def const_num_val(self) -> Union[int, float]: def const_value(self) -> Union[int, float, bool, str]:
if isinstance(self.value, (int, float)):
return self.value return self.value
raise TypeError("literal value is not numeric", self)
def is_compile_constant(self) -> bool:
return True
@attr.s(cmp=False) @attr.s(cmp=False)
class AddressOf(AstNode): class AddressOf(Expression):
# no subnodes. # no subnodes.
name = attr.ib(type=str) name = attr.ib(type=str)
def const_num_val(self) -> Union[int, float]: def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef): if isinstance(symdef, VarDef):
if symdef.zp_address is not None: if symdef.zp_address is not None:
return symdef.zp_address return symdef.zp_address
if symdef.vartype == VarType.MEMORY: if symdef.vartype == VarType.MEMORY:
return symdef.value.const_num_val() return symdef.value.const_value()
raise TypeError("can only take constant address of a memory mapped variable", self) raise TypeError("can only take constant address of a memory mapped variable", self)
raise TypeError("should be a vardef to be able to take its address", self) raise TypeError("should be a vardef to be able to take its address", self)
@attr.s(cmp=False, slots=True) @attr.s(cmp=False, slots=True)
class SymbolName(AstNode): class SymbolName(Expression):
# no subnodes. # no subnodes.
name = attr.ib(type=str) name = attr.ib(type=str)
def const_num_val(self) -> Union[int, float]: def is_compile_constant(self) -> bool:
symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST
def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name) symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
if symdef.datatype.isnumeric(): return symdef.const_value()
return symdef.const_num_val() raise TypeError("should be a const vardef to be able to take its constant numeric value", self)
raise TypeError("not a constant value", self)
raise TypeError("should be a vardef to be able to take its constant numeric value", self)
@attr.s(cmp=False) @attr.s(cmp=False)
class Dereference(AstNode): class Dereference(Expression):
# one subnode: operand (SymbolName, int or register name) # one subnode: operand (SymbolName, int or register name)
datatype = attr.ib() datatype = attr.ib()
size = attr.ib(type=int, default=None) size = attr.ib(type=int, default=None)
@ -452,6 +476,12 @@ class Dereference(AstNode):
raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref) raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref)
self.datatype = self.datatype.to_enum() self.datatype = self.datatype.to_enum()
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("dereference is not a constant numeric value")
@attr.s(cmp=False) @attr.s(cmp=False)
class IncrDecr(AstNode): class IncrDecr(AstNode):
@ -483,12 +513,12 @@ class IncrDecr(AstNode):
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode): class ExpressionWithOperator(Expression):
left = attr.ib() left = attr.ib() # type: Expression
operator = attr.ib(type=str) operator = attr.ib(type=str)
right = attr.ib() right = attr.ib() # type: Expression
unary = attr.ib(type=bool, default=False) unary = attr.ib(type=bool, default=False)
# when evaluating an expression, does it have to be a constant value? # when evaluating the expression, does it have to be a compile-time constant value?
must_be_constant = attr.ib(type=bool, init=False, default=False) must_be_constant = attr.ib(type=bool, init=False, default=False)
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -496,9 +526,12 @@ class Expression(AstNode):
if self.operator == "mod": if self.operator == "mod":
self.operator = "%" # change it back to the more common '%' self.operator = "%" # change it back to the more common '%'
def const_num_val(self) -> Union[int, float]: def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("an expression is not a constant", self) raise TypeError("an expression is not a constant", self)
def is_compile_constant(self) -> bool:
return False
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue: def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
# make sure the lvalue and rvalue are primitives, and the operator is allowed # make sure the lvalue and rvalue are primitives, and the operator is allowed
assert isinstance(self.left, LiteralValue) assert isinstance(self.left, LiteralValue)
@ -513,19 +546,6 @@ class Expression(AstNode):
except Exception as x: except Exception as x:
raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None
def print_tree(self) -> None:
def tree(expr: Any, level: int) -> str:
indent = " "*level
if not isinstance(expr, Expression):
return indent + str(expr) + "\n"
if expr.unary:
return indent + "{}{}".format(expr.operator, tree(expr.left, level+1))
else:
return indent + "{}".format(tree(expr.left, level+1)) + \
indent + str(expr.operator) + "\n" + \
indent + "{}".format(tree(expr.right, level + 1))
print(tree(self, 0))
@attr.s(cmp=False, repr=False) @attr.s(cmp=False, repr=False)
class Goto(AstNode): class Goto(AstNode):
@ -537,7 +557,7 @@ class Goto(AstNode):
return self.nodes[0] # type: ignore return self.nodes[0] # type: ignore
@property @property
def condition(self) -> Expression: def condition(self) -> Optional[Expression]:
return self.nodes[1] if len(self.nodes) == 2 else None # type: ignore return self.nodes[1] if len(self.nodes) == 2 else None # type: ignore
@ -558,7 +578,7 @@ class CallArguments(AstNode):
@attr.s(cmp=False, repr=False) @attr.s(cmp=False, repr=False)
class SubCall(AstNode): class SubCall(Expression):
# has three subnodes: # has three subnodes:
# 0: target (Symbolname, int, or Dereference), # 0: target (Symbolname, int, or Dereference),
# 1: preserve_regs (PreserveRegs) # 1: preserve_regs (PreserveRegs)
@ -576,10 +596,16 @@ class SubCall(AstNode):
def arguments(self) -> CallArguments: def arguments(self) -> CallArguments:
return self.nodes[2] # type: ignore return self.nodes[2] # type: ignore
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("subroutine call is not a constant value", self)
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
class VarDef(AstNode): class VarDef(AstNode):
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.). # zero or one subnode: value (Expression).
name = attr.ib(type=str) name = attr.ib(type=str)
vartype = attr.ib() vartype = attr.ib()
datatype = attr.ib() datatype = attr.ib()
@ -587,29 +613,26 @@ class VarDef(AstNode):
zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
@property @property
def value(self) -> Union[LiteralValue, Expression, AddressOf, SymbolName]: def value(self) -> Expression:
return self.nodes[0] if self.nodes else None # type: ignore return self.nodes[0] if self.nodes else None # type: ignore
@value.setter @value.setter
def value(self, value: Union[LiteralValue, Expression, AddressOf, SymbolName]) -> None: def value(self, value: Expression) -> None:
assert isinstance(value, (LiteralValue, Expression, AddressOf, SymbolName)) assert isinstance(value, Expression)
if self.nodes: if self.nodes:
self.nodes[0] = value self.nodes[0] = value
else: else:
self.nodes.append(value) self.nodes.append(value)
# if the value is an expression, mark it as a *constant* expression here if isinstance(value, ExpressionWithOperator):
if isinstance(value, Expression): # an expression in a vardef should evaluate to a compile-time constant:
value.must_be_constant = True value.must_be_constant = True
def const_num_val(self) -> Union[int, float]: def const_value(self) -> Union[int, float, bool, str]:
if self.vartype != VarType.CONST: if self.vartype != VarType.CONST:
raise TypeError("not a constant value", self) raise TypeError("not a constant value", self)
if self.datatype.isnumeric(): if self.nodes and isinstance(self.nodes[0], Expression):
if self.nodes: return self.nodes[0].const_value()
return self.nodes[0].const_num_val() # type: ignore
raise ValueError("no value", self) raise ValueError("no value", self)
else:
raise TypeError("not numeric", self)
def __attrs_post_init__(self): def __attrs_post_init__(self):
# convert vartype to enum # convert vartype to enum
@ -647,15 +670,15 @@ class VarDef(AstNode):
class Return(AstNode): class Return(AstNode):
# one, two or three subnodes: value_A, value_X, value_Y (all three Expression) # one, two or three subnodes: value_A, value_X, value_Y (all three Expression)
@property @property
def value_A(self) -> Expression: def value_A(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore return self.nodes[0] if self.nodes else None # type: ignore
@property @property
def value_X(self) -> Expression: def value_X(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore return self.nodes[0] if self.nodes else None # type: ignore
@property @property
def value_Y(self) -> Expression: def value_Y(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore return self.nodes[0] if self.nodes else None # type: ignore
@ -709,20 +732,20 @@ class AssignmentTargets(AstNode):
@attr.s(cmp=False, slots=True, repr=False) @attr.s(cmp=False, slots=True, repr=False)
class Assignment(AstNode): class Assignment(AstNode):
# can be single- or multi-assignment # can be single- or multi-assignment
# has two subnodes: left (=AssignmentTargets) and right (=reg/literal/expr # has two subnodes: left (=AssignmentTargets) and right (=Expression,
# or another Assignment but those will be converted to multi assign) # or another Assignment but those will be converted into multi assign)
@property @property
def left(self) -> AssignmentTargets: def left(self) -> AssignmentTargets:
return self.nodes[0] # type: ignore return self.nodes[0] # type: ignore
@property @property
def right(self) -> Union[Register, LiteralValue, Expression]: def right(self) -> Expression:
return self.nodes[1] # type: ignore return self.nodes[1] # type: ignore
@right.setter @right.setter
def right(self, rvalue: Union[Register, LiteralValue, Expression, Dereference, SymbolName, SubCall]) -> None: def right(self, rvalue: Expression) -> None:
assert isinstance(rvalue, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall)) assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue self.nodes[1] = rvalue
@ -789,7 +812,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT):
if type(value.value) not in (int, float): if type(value.value) not in (int, float):
raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref) raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref)
elif isinstance(value, (Expression, SubCall)): elif isinstance(value, (ExpressionWithOperator, SubCall)):
return False, value return False, value
elif isinstance(value, SymbolName): elif isinstance(value, SymbolName):
symboldef = value.my_scope().lookup(value.name) symboldef = value.my_scope().lookup(value.name)
@ -797,7 +820,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
return True, symboldef.value return True, symboldef.value
elif isinstance(value, AddressOf): elif isinstance(value, AddressOf):
try: try:
address = value.const_num_val() address = value.const_value()
return True, LiteralValue(value=address, sourceref=value.sourceref) # type: ignore return True, LiteralValue(value=address, sourceref=value.sourceref) # type: ignore
except TypeError: except TypeError:
return False, value return False, value
@ -1357,14 +1380,14 @@ def p_expression(p):
| expression EQUALS expression | expression EQUALS expression
| expression NOTEQUALS expression | expression NOTEQUALS expression
""" """
p[0] = Expression(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2)) p[0] = ExpressionWithOperator(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2))
def p_expression_uminus(p): def p_expression_uminus(p):
""" """
expression : '-' expression %prec UNARY_MINUS expression : '-' expression %prec UNARY_MINUS
""" """
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_expression_addressof(p): def p_expression_addressof(p):
@ -1378,14 +1401,14 @@ def p_unary_expression_bitinvert(p):
""" """
expression : BITINVERT expression expression : BITINVERT expression
""" """
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_unary_expression_logicnot(p): def p_unary_expression_logicnot(p):
""" """
expression : LOGICNOT expression expression : LOGICNOT expression
""" """
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_expression_group(p): def p_expression_group(p):

View File

@ -1,6 +1,7 @@
import pytest
from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef
from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, Return, Scope, \ from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \
VarDef, Expression, LiteralValue, Label, SubCall, Dereference VarDef, Expression, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference
from il65.datatypes import DataType from il65.datatypes import DataType
@ -127,7 +128,7 @@ def test_parser():
assert block.name == "block" assert block.name == "block"
bool_vdef = block.scope.nodes[1] bool_vdef = block.scope.nodes[1]
assert isinstance(bool_vdef, VarDef) assert isinstance(bool_vdef, VarDef)
assert isinstance(bool_vdef.value, Expression) assert isinstance(bool_vdef.value, ExpressionWithOperator)
assert isinstance(bool_vdef.value.right, LiteralValue) assert isinstance(bool_vdef.value.right, LiteralValue)
assert isinstance(bool_vdef.value.right.value, int) assert isinstance(bool_vdef.value.right.value, int)
assert bool_vdef.value.right.value == 1 assert bool_vdef.value.right.value == 1
@ -283,3 +284,11 @@ def test_boolean_int():
assert type(var2.value.value) is int and var2.value.value == 0 assert type(var2.value.value) is int and var2.value.value == 0
assert type(assgn1.right.value) is int and assgn1.right.value == 1 assert type(assgn1.right.value) is int and assgn1.right.value == 1
assert type(assgn2.right.value) is int and assgn2.right.value == 0 assert type(assgn2.right.value) is int and assgn2.right.value == 0
def test_incrdecr():
sref = SourceRef("test", 1, 1)
with pytest.raises(ValueError):
IncrDecr(operator="??", sourceref=sref)
i = IncrDecr(operator="++", sourceref=sref)
assert i.howmuch == 1

View File

@ -1,6 +1,6 @@
import pytest import pytest
from il65.datatypes import DataType from il65.datatypes import DataType
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, Expression, Scope, AddressOf, SymbolName, UndefinedSymbolError from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError
from il65.plylex import SourceRef from il65.plylex import SourceRef
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.). # zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
@ -61,14 +61,14 @@ def test_set_value():
assert v.value is None assert v.value is None
v.value = LiteralValue(value="hello", sourceref=sref) v.value = LiteralValue(value="hello", sourceref=sref)
assert v.value.value == "hello" assert v.value.value == "hello"
e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
assert not e.must_be_constant assert not e.must_be_constant
v.value = e v.value = e
assert v.value is e assert v.value is e
assert e.must_be_constant assert e.must_be_constant
def test_const_num_val(): def test_const_value():
sref = SourceRef("test", 1, 1) sref = SourceRef("test", 1, 1)
scope = Scope(nodes=[], level="block", sourceref=sref) scope = Scope(nodes=[], level="block", sourceref=sref)
vardef = VarDef(name="constvar", vartype="const", datatype=None, sourceref=sref) vardef = VarDef(name="constvar", vartype="const", datatype=None, sourceref=sref)
@ -82,37 +82,37 @@ def test_const_num_val():
scope.add_node(vardef) scope.add_node(vardef)
v = VarDef(name="v1", vartype="var", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) v = VarDef(name="v1", vartype="var", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
with pytest.raises(TypeError): with pytest.raises(TypeError):
v.const_num_val() v.const_value()
v = VarDef(name="v1", vartype="memory", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) v = VarDef(name="v1", vartype="memory", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
with pytest.raises(TypeError): with pytest.raises(TypeError):
v.const_num_val() v.const_value()
v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
assert v.const_num_val() == 0 assert v.const_value() == 0
v.value = LiteralValue(value=42, sourceref=sref) v.value = LiteralValue(value=42, sourceref=sref)
assert v.const_num_val() == 42 assert v.const_value() == 42
v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="float", sourceref=sref), sourceref=sref) v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="float", sourceref=sref), sourceref=sref)
assert v.const_num_val() == 0 assert v.const_value() == 0
v.value = LiteralValue(value=42.9988, sourceref=sref) v.value = LiteralValue(value=42.9988, sourceref=sref)
assert v.const_num_val() == 42.9988 assert v.const_value() == 42.9988
e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
v.value = e v.value = e
with pytest.raises(TypeError): with pytest.raises(TypeError):
v.const_num_val() v.const_value()
s = SymbolName(name="unexisting", sourceref=sref) s = SymbolName(name="unexisting", sourceref=sref)
s.parent = scope s.parent = scope
v.value = s v.value = s
with pytest.raises(UndefinedSymbolError): with pytest.raises(UndefinedSymbolError):
v.const_num_val() v.const_value()
s = SymbolName(name="constvar", sourceref=sref) s = SymbolName(name="constvar", sourceref=sref)
s.parent = scope s.parent = scope
v.value = s v.value = s
assert v.const_num_val() == 43 assert v.const_value() == 43
a = AddressOf(name="varvar", sourceref=sref) a = AddressOf(name="varvar", sourceref=sref)
a.parent = scope a.parent = scope
v.value = a v.value = a
with pytest.raises(TypeError): with pytest.raises(TypeError):
v.const_num_val() v.const_value()
a = AddressOf(name="memvar", sourceref=sref) a = AddressOf(name="memvar", sourceref=sref)
a.parent = scope a.parent = scope
v.value = a v.value = a
assert v.const_num_val() == 45 assert v.const_value() == 45

View File

@ -5,16 +5,15 @@
start: start:
Y+=10
Y-=5 Y-=5
Y-=8 Y-=8
Y-- Y--
flt+=2 ;flt+=2 ; @todo implement on float
flt+=2 ;flt+=2
flt+=2 ;flt+=2
flt+=2 ;flt+=2
flt+=2 ;flt+=2
X=0 X=0
X+=5 X+=5
@ -27,7 +26,8 @@ start:
X=24|X X=24|X
X=X^66 X=X^66
X+=250 X+=250
X+=5 X=5+2+3+4
X+=5+2+3+4
X-=100 X-=100
X-=50 X-=50
X-=5 X-=5
@ -52,6 +52,7 @@ start:
;flt += 10.1 ;flt += 10.1
;flt += 100.1 ;flt += 100.1
;flt += 1000.1 ;flt += 1000.1
;flt *= 2.34
return 44 return 44
} }