more explicit use of Expression, fixed some optimizations

This commit is contained in:
Irmen de Jong 2018-01-29 22:34:28 +01:00
parent f82ceab969
commit 4d70e3d42f
7 changed files with 206 additions and 167 deletions

View File

@ -192,7 +192,7 @@ class PlyParser:
lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes)
if len(lvalue_types) == 1:
_, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref)
if isinstance(newright, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall)):
if isinstance(newright, Expression):
node.right = newright # type: ignore
else:
raise TypeError("invalid coerced constant type", newright)
@ -322,7 +322,7 @@ class PlyParser:
return
elif isinstance(expr, SubCall):
self._get_subroutine_usages_from_subcall(usages, expr, parent_scope)
elif isinstance(expr, Expression):
elif isinstance(expr, ExpressionWithOperator):
self._get_subroutine_usages_from_expression(usages, expr.left, parent_scope)
self._get_subroutine_usages_from_expression(usages, expr.right, parent_scope)
elif isinstance(expr, LiteralValue):

View File

@ -36,7 +36,7 @@ def generate_aug_assignment(out: Callable, stmt: AugAssignment, scope: Scope) ->
elif isinstance(rvalue, SymbolName):
symdef = scope.lookup(rvalue.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST and symdef.datatype.isinteger():
if 0 <= symdef.value.const_num_val() <= 255:
if 0 <= symdef.value.const_value() <= 255: # type: ignore
_generate_aug_reg_constant_int(out, lvalue, stmt.operator, 0, symdef.name, scope)
else:
raise CodeError("assignment value must be 0..255", rvalue)
@ -136,7 +136,7 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo ^=.word
elif operator == ">>=":
if rvalue.value > 0:
if rvalue > 0:
def shifts_A(times: int) -> None:
if times >= 8:
out("\vlda #0")
@ -144,21 +144,21 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
for _ in range(min(8, times)):
out("\vlsr a")
if lvalue.name == "A":
shifts_A(rvalue.value)
shifts_A(rvalue)
elif lvalue.name == "X":
with preserving_registers({'A'}, scope, out):
out("\vtxa")
shifts_A(rvalue.value)
shifts_A(rvalue)
out("\vtax")
elif lvalue.name == "Y":
with preserving_registers({'A'}, scope, out):
out("\vtya")
shifts_A(rvalue.value)
shifts_A(rvalue)
out("\vtay")
else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo >>=.word
elif operator == "<<=":
if rvalue.value > 0:
if rvalue > 0:
def shifts_A(times: int) -> None:
if times >= 8:
out("\vlda #0")
@ -166,16 +166,16 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st
for _ in range(min(8, times)):
out("\vasl a")
if lvalue.name == "A":
shifts_A(rvalue.value)
shifts_A(rvalue)
elif lvalue.name == "X":
with preserving_registers({'A'}, scope, out):
out("\vtxa")
shifts_A(rvalue.value)
shifts_A(rvalue)
out("\vtax")
elif lvalue.name == "Y":
with preserving_registers({'A'}, scope, out):
out("\vtya")
shifts_A(rvalue.value)
shifts_A(rvalue)
out("\vtay")
else:
raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo <<=.word

View File

@ -58,11 +58,14 @@ class Optimizer:
def constant_folding(self) -> None:
for expression in self.module.all_nodes(Expression):
if isinstance(expression, LiteralValue):
continue
try:
evaluated = process_expression(expression, expression.sourceref) # type: ignore
evaluated = process_expression(expression) # type: ignore
if evaluated is not expression:
# replace the node with the newly evaluated result
expression.parent.replace_node(expression, evaluated)
self.optimizations_performed = True
except ParseError:
raise
except Exception as x:
@ -148,12 +151,12 @@ class Optimizer:
for assignment in self.module.all_nodes(Assignment):
if len(assignment.left.nodes) > 1:
continue
if not isinstance(assignment.right, Expression) or assignment.right.unary:
if not isinstance(assignment.right, ExpressionWithOperator) or assignment.right.unary:
continue
expr = assignment.right
if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
num_val = expr.right.const_num_val()
num_val = expr.right.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
@ -162,13 +165,13 @@ class Optimizer:
if expr.operator not in ('+', '*', '|', '^'): # associative operators
continue
if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left):
num_val = expr.right.const_num_val()
num_val = expr.right.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right):
num_val = expr.left.const_num_val()
num_val = expr.left.const_value()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
@ -189,6 +192,7 @@ class Optimizer:
print_warning("{}: removed superfluous assignment".format(prev_node.sourceref))
prev_node = node
@no_type_check
def optimize_assignments(self) -> None:
# remove assignment statements that do nothing (A=A)
# remove augmented assignments that have no effect (x+=0, x-=0, x/=1, x//=1, x*=1)
@ -385,27 +389,27 @@ class Optimizer:
node.my_scope().nodes.remove(node)
def process_expression(expr: Expression, sourceref: SourceRef) -> Any:
def process_expression(expr: Expression) -> Any:
# process/simplify all expressions (constant folding etc)
if expr.must_be_constant:
return process_constant_expression(expr, sourceref)
if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant:
return process_constant_expression(expr, expr.sourceref)
else:
return process_dynamic_expression(expr, sourceref)
return process_dynamic_expression(expr, expr.sourceref)
def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue:
def process_constant_expression(expr: Expression, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, Expression):
if isinstance(value, ExpressionWithOperator):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue):
return value
@ -452,45 +456,46 @@ def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else:
raise NotImplementedError("weird call target", expr.target)
elif not isinstance(expr, Expression):
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_constant_expression(expr.right, right_sourceref)
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_constant_expression(expr.right, right_sourceref)
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
else:
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
def process_dynamic_expression(expr: Expression, sourceref: SourceRef) -> Any:
# constant-fold a dynamic expression
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
if isinstance(expr, LiteralValue):
return expr
if expr.is_compile_constant():
return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore
elif isinstance(expr, SymbolName):
try:
return process_constant_expression(expr, sourceref)
@ -514,24 +519,25 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
if isinstance(expr.operand, SymbolName):
check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref)
return expr
elif not isinstance(expr, Expression):
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
elif isinstance(expr, ExpressionWithOperator):
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_dynamic_expression(expr.right, right_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_dynamic_expression(expr.right, right_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
def optimize(mod: Module) -> None:

View File

@ -77,10 +77,11 @@ class AstNode:
def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]:
nodetypes = nodetypes or (AstNode, )
for node in list(self.nodes):
if isinstance(node, nodetypes): # type: ignore
child_nodes = list(self.nodes)
for node in child_nodes:
if isinstance(node, nodetypes):
yield node
for node in self.nodes:
for node in child_nodes:
if isinstance(node, AstNode):
yield from node.all_nodes(*nodetypes)
@ -293,8 +294,19 @@ class Label(AstNode):
# no subnodes.
@attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode):
# just a common base class for the nodes that are an expression themselves:
# ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference
def is_compile_constant(self) -> bool:
raise NotImplementedError("implement in subclass")
def const_value(self) -> Union[int, float, bool, str]:
raise NotImplementedError("implement in subclass")
@attr.s(cmp=False, slots=True)
class Register(AstNode):
class Register(Expression):
name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS))
datatype = attr.ib(type=DataType, init=False)
# no subnodes.
@ -320,6 +332,12 @@ class Register(AstNode):
return NotImplemented
return self.name < other.name
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("register doesn't have a constant numeric value", self)
@attr.s(cmp=False)
class PreserveRegs(AstNode):
@ -386,51 +404,57 @@ class Subroutine(AstNode):
@attr.s(cmp=True, slots=True, repr=False)
class LiteralValue(AstNode):
class LiteralValue(Expression):
# no subnodes.
value = attr.ib()
def __repr__(self) -> str:
return "<LiteralValue value={!r} at {}>".format(self.value, self.sourceref)
def const_num_val(self) -> Union[int, float]:
if isinstance(self.value, (int, float)):
return self.value
raise TypeError("literal value is not numeric", self)
def const_value(self) -> Union[int, float, bool, str]:
return self.value
def is_compile_constant(self) -> bool:
return True
@attr.s(cmp=False)
class AddressOf(AstNode):
class AddressOf(Expression):
# no subnodes.
name = attr.ib(type=str)
def const_num_val(self) -> Union[int, float]:
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef):
if symdef.zp_address is not None:
return symdef.zp_address
if symdef.vartype == VarType.MEMORY:
return symdef.value.const_num_val()
return symdef.value.const_value()
raise TypeError("can only take constant address of a memory mapped variable", self)
raise TypeError("should be a vardef to be able to take its address", self)
@attr.s(cmp=False, slots=True)
class SymbolName(AstNode):
class SymbolName(Expression):
# no subnodes.
name = attr.ib(type=str)
def const_num_val(self) -> Union[int, float]:
def is_compile_constant(self) -> bool:
symdef = self.my_scope().lookup(self.name)
return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST
def const_value(self) -> Union[int, float, bool, str]:
symdef = self.my_scope().lookup(self.name)
if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST:
if symdef.datatype.isnumeric():
return symdef.const_num_val()
raise TypeError("not a constant value", self)
raise TypeError("should be a vardef to be able to take its constant numeric value", self)
return symdef.const_value()
raise TypeError("should be a const vardef to be able to take its constant numeric value", self)
@attr.s(cmp=False)
class Dereference(AstNode):
class Dereference(Expression):
# one subnode: operand (SymbolName, int or register name)
datatype = attr.ib()
size = attr.ib(type=int, default=None)
@ -452,6 +476,12 @@ class Dereference(AstNode):
raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref)
self.datatype = self.datatype.to_enum()
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("dereference is not a constant numeric value")
@attr.s(cmp=False)
class IncrDecr(AstNode):
@ -483,12 +513,12 @@ class IncrDecr(AstNode):
@attr.s(cmp=False, slots=True, repr=False)
class Expression(AstNode):
left = attr.ib()
class ExpressionWithOperator(Expression):
left = attr.ib() # type: Expression
operator = attr.ib(type=str)
right = attr.ib()
right = attr.ib() # type: Expression
unary = attr.ib(type=bool, default=False)
# when evaluating an expression, does it have to be a constant value?
# when evaluating the expression, does it have to be a compile-time constant value?
must_be_constant = attr.ib(type=bool, init=False, default=False)
def __attrs_post_init__(self):
@ -496,9 +526,12 @@ class Expression(AstNode):
if self.operator == "mod":
self.operator = "%" # change it back to the more common '%'
def const_num_val(self) -> Union[int, float]:
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("an expression is not a constant", self)
def is_compile_constant(self) -> bool:
return False
def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue:
# make sure the lvalue and rvalue are primitives, and the operator is allowed
assert isinstance(self.left, LiteralValue)
@ -513,19 +546,6 @@ class Expression(AstNode):
except Exception as x:
raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None
def print_tree(self) -> None:
def tree(expr: Any, level: int) -> str:
indent = " "*level
if not isinstance(expr, Expression):
return indent + str(expr) + "\n"
if expr.unary:
return indent + "{}{}".format(expr.operator, tree(expr.left, level+1))
else:
return indent + "{}".format(tree(expr.left, level+1)) + \
indent + str(expr.operator) + "\n" + \
indent + "{}".format(tree(expr.right, level + 1))
print(tree(self, 0))
@attr.s(cmp=False, repr=False)
class Goto(AstNode):
@ -537,7 +557,7 @@ class Goto(AstNode):
return self.nodes[0] # type: ignore
@property
def condition(self) -> Expression:
def condition(self) -> Optional[Expression]:
return self.nodes[1] if len(self.nodes) == 2 else None # type: ignore
@ -558,7 +578,7 @@ class CallArguments(AstNode):
@attr.s(cmp=False, repr=False)
class SubCall(AstNode):
class SubCall(Expression):
# has three subnodes:
# 0: target (Symbolname, int, or Dereference),
# 1: preserve_regs (PreserveRegs)
@ -576,10 +596,16 @@ class SubCall(AstNode):
def arguments(self) -> CallArguments:
return self.nodes[2] # type: ignore
def is_compile_constant(self) -> bool:
return False
def const_value(self) -> Union[int, float, bool, str]:
raise TypeError("subroutine call is not a constant value", self)
@attr.s(cmp=False, slots=True, repr=False)
class VarDef(AstNode):
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
# zero or one subnode: value (Expression).
name = attr.ib(type=str)
vartype = attr.ib()
datatype = attr.ib()
@ -587,29 +613,26 @@ class VarDef(AstNode):
zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
@property
def value(self) -> Union[LiteralValue, Expression, AddressOf, SymbolName]:
def value(self) -> Expression:
return self.nodes[0] if self.nodes else None # type: ignore
@value.setter
def value(self, value: Union[LiteralValue, Expression, AddressOf, SymbolName]) -> None:
assert isinstance(value, (LiteralValue, Expression, AddressOf, SymbolName))
def value(self, value: Expression) -> None:
assert isinstance(value, Expression)
if self.nodes:
self.nodes[0] = value
else:
self.nodes.append(value)
# if the value is an expression, mark it as a *constant* expression here
if isinstance(value, Expression):
if isinstance(value, ExpressionWithOperator):
# an expression in a vardef should evaluate to a compile-time constant:
value.must_be_constant = True
def const_num_val(self) -> Union[int, float]:
def const_value(self) -> Union[int, float, bool, str]:
if self.vartype != VarType.CONST:
raise TypeError("not a constant value", self)
if self.datatype.isnumeric():
if self.nodes:
return self.nodes[0].const_num_val() # type: ignore
raise ValueError("no value", self)
else:
raise TypeError("not numeric", self)
if self.nodes and isinstance(self.nodes[0], Expression):
return self.nodes[0].const_value()
raise ValueError("no value", self)
def __attrs_post_init__(self):
# convert vartype to enum
@ -647,15 +670,15 @@ class VarDef(AstNode):
class Return(AstNode):
# one, two or three subnodes: value_A, value_X, value_Y (all three Expression)
@property
def value_A(self) -> Expression:
def value_A(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore
@property
def value_X(self) -> Expression:
def value_X(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore
@property
def value_Y(self) -> Expression:
def value_Y(self) -> Optional[Expression]:
return self.nodes[0] if self.nodes else None # type: ignore
@ -709,20 +732,20 @@ class AssignmentTargets(AstNode):
@attr.s(cmp=False, slots=True, repr=False)
class Assignment(AstNode):
# can be single- or multi-assignment
# has two subnodes: left (=AssignmentTargets) and right (=reg/literal/expr
# or another Assignment but those will be converted to multi assign)
# has two subnodes: left (=AssignmentTargets) and right (=Expression,
# or another Assignment but those will be converted into multi assign)
@property
def left(self) -> AssignmentTargets:
return self.nodes[0] # type: ignore
@property
def right(self) -> Union[Register, LiteralValue, Expression]:
def right(self) -> Expression:
return self.nodes[1] # type: ignore
@right.setter
def right(self, rvalue: Union[Register, LiteralValue, Expression, Dereference, SymbolName, SubCall]) -> None:
assert isinstance(rvalue, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall))
def right(self, rvalue: Expression) -> None:
assert isinstance(rvalue, Expression)
self.nodes[1] = rvalue
@ -789,7 +812,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT):
if type(value.value) not in (int, float):
raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref)
elif isinstance(value, (Expression, SubCall)):
elif isinstance(value, (ExpressionWithOperator, SubCall)):
return False, value
elif isinstance(value, SymbolName):
symboldef = value.my_scope().lookup(value.name)
@ -797,7 +820,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
return True, symboldef.value
elif isinstance(value, AddressOf):
try:
address = value.const_num_val()
address = value.const_value()
return True, LiteralValue(value=address, sourceref=value.sourceref) # type: ignore
except TypeError:
return False, value
@ -1357,14 +1380,14 @@ def p_expression(p):
| expression EQUALS expression
| expression NOTEQUALS expression
"""
p[0] = Expression(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2))
p[0] = ExpressionWithOperator(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2))
def p_expression_uminus(p):
"""
expression : '-' expression %prec UNARY_MINUS
"""
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_expression_addressof(p):
@ -1378,14 +1401,14 @@ def p_unary_expression_bitinvert(p):
"""
expression : BITINVERT expression
"""
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_unary_expression_logicnot(p):
"""
expression : LOGICNOT expression
"""
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_expression_group(p):

View File

@ -1,6 +1,7 @@
import pytest
from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef
from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, Return, Scope, \
VarDef, Expression, LiteralValue, Label, SubCall, Dereference
from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \
VarDef, Expression, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference
from il65.datatypes import DataType
@ -127,7 +128,7 @@ def test_parser():
assert block.name == "block"
bool_vdef = block.scope.nodes[1]
assert isinstance(bool_vdef, VarDef)
assert isinstance(bool_vdef.value, Expression)
assert isinstance(bool_vdef.value, ExpressionWithOperator)
assert isinstance(bool_vdef.value.right, LiteralValue)
assert isinstance(bool_vdef.value.right.value, int)
assert bool_vdef.value.right.value == 1
@ -283,3 +284,11 @@ def test_boolean_int():
assert type(var2.value.value) is int and var2.value.value == 0
assert type(assgn1.right.value) is int and assgn1.right.value == 1
assert type(assgn2.right.value) is int and assgn2.right.value == 0
def test_incrdecr():
sref = SourceRef("test", 1, 1)
with pytest.raises(ValueError):
IncrDecr(operator="??", sourceref=sref)
i = IncrDecr(operator="++", sourceref=sref)
assert i.howmuch == 1

View File

@ -1,6 +1,6 @@
import pytest
from il65.datatypes import DataType
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, Expression, Scope, AddressOf, SymbolName, UndefinedSymbolError
from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError
from il65.plylex import SourceRef
# zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.).
@ -61,14 +61,14 @@ def test_set_value():
assert v.value is None
v.value = LiteralValue(value="hello", sourceref=sref)
assert v.value.value == "hello"
e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
assert not e.must_be_constant
v.value = e
assert v.value is e
assert e.must_be_constant
def test_const_num_val():
def test_const_value():
sref = SourceRef("test", 1, 1)
scope = Scope(nodes=[], level="block", sourceref=sref)
vardef = VarDef(name="constvar", vartype="const", datatype=None, sourceref=sref)
@ -82,37 +82,37 @@ def test_const_num_val():
scope.add_node(vardef)
v = VarDef(name="v1", vartype="var", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
with pytest.raises(TypeError):
v.const_num_val()
v.const_value()
v = VarDef(name="v1", vartype="memory", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
with pytest.raises(TypeError):
v.const_num_val()
v.const_value()
v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref)
assert v.const_num_val() == 0
assert v.const_value() == 0
v.value = LiteralValue(value=42, sourceref=sref)
assert v.const_num_val() == 42
assert v.const_value() == 42
v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="float", sourceref=sref), sourceref=sref)
assert v.const_num_val() == 0
assert v.const_value() == 0
v.value = LiteralValue(value=42.9988, sourceref=sref)
assert v.const_num_val() == 42.9988
e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
assert v.const_value() == 42.9988
e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref)
v.value = e
with pytest.raises(TypeError):
v.const_num_val()
v.const_value()
s = SymbolName(name="unexisting", sourceref=sref)
s.parent = scope
v.value = s
with pytest.raises(UndefinedSymbolError):
v.const_num_val()
v.const_value()
s = SymbolName(name="constvar", sourceref=sref)
s.parent = scope
v.value = s
assert v.const_num_val() == 43
assert v.const_value() == 43
a = AddressOf(name="varvar", sourceref=sref)
a.parent = scope
v.value = a
with pytest.raises(TypeError):
v.const_num_val()
v.const_value()
a = AddressOf(name="memvar", sourceref=sref)
a.parent = scope
v.value = a
assert v.const_num_val() == 45
assert v.const_value() == 45

View File

@ -5,16 +5,15 @@
start:
Y+=10
Y-=5
Y-=8
Y--
flt+=2
flt+=2
flt+=2
flt+=2
flt+=2
;flt+=2 ; @todo implement on float
;flt+=2
;flt+=2
;flt+=2
;flt+=2
X=0
X+=5
@ -27,7 +26,8 @@ start:
X=24|X
X=X^66
X+=250
X+=5
X=5+2+3+4
X+=5+2+3+4
X-=100
X-=50
X-=5
@ -52,6 +52,7 @@ start:
;flt += 10.1
;flt += 100.1
;flt += 1000.1
;flt *= 2.34
return 44
}