diff --git a/il65/compile.py b/il65/compile.py index 3e07b2c86..d64384732 100644 --- a/il65/compile.py +++ b/il65/compile.py @@ -192,7 +192,7 @@ class PlyParser: lvalue_types = set(datatype_of(lv, node.my_scope()) for lv in node.left.nodes) if len(lvalue_types) == 1: _, newright = coerce_constant_value(lvalue_types.pop(), node.right, node.sourceref) - if isinstance(newright, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall)): + if isinstance(newright, Expression): node.right = newright # type: ignore else: raise TypeError("invalid coerced constant type", newright) @@ -322,7 +322,7 @@ class PlyParser: return elif isinstance(expr, SubCall): self._get_subroutine_usages_from_subcall(usages, expr, parent_scope) - elif isinstance(expr, Expression): + elif isinstance(expr, ExpressionWithOperator): self._get_subroutine_usages_from_expression(usages, expr.left, parent_scope) self._get_subroutine_usages_from_expression(usages, expr.right, parent_scope) elif isinstance(expr, LiteralValue): diff --git a/il65/emit/assignment.py b/il65/emit/assignment.py index 2e781e3d9..ef3b8d19e 100644 --- a/il65/emit/assignment.py +++ b/il65/emit/assignment.py @@ -36,7 +36,7 @@ def generate_aug_assignment(out: Callable, stmt: AugAssignment, scope: Scope) -> elif isinstance(rvalue, SymbolName): symdef = scope.lookup(rvalue.name) if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST and symdef.datatype.isinteger(): - if 0 <= symdef.value.const_num_val() <= 255: + if 0 <= symdef.value.const_value() <= 255: # type: ignore _generate_aug_reg_constant_int(out, lvalue, stmt.operator, 0, symdef.name, scope) else: raise CodeError("assignment value must be 0..255", rvalue) @@ -136,7 +136,7 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st else: raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo ^=.word elif operator == ">>=": - if rvalue.value > 0: + if rvalue > 0: def shifts_A(times: int) -> None: if times >= 8: out("\vlda #0") @@ -144,21 +144,21 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st for _ in range(min(8, times)): out("\vlsr a") if lvalue.name == "A": - shifts_A(rvalue.value) + shifts_A(rvalue) elif lvalue.name == "X": with preserving_registers({'A'}, scope, out): out("\vtxa") - shifts_A(rvalue.value) + shifts_A(rvalue) out("\vtax") elif lvalue.name == "Y": with preserving_registers({'A'}, scope, out): out("\vtya") - shifts_A(rvalue.value) + shifts_A(rvalue) out("\vtay") else: raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo >>=.word elif operator == "<<=": - if rvalue.value > 0: + if rvalue > 0: def shifts_A(times: int) -> None: if times >= 8: out("\vlda #0") @@ -166,16 +166,16 @@ def _generate_aug_reg_constant_int(out: Callable, lvalue: Register, operator: st for _ in range(min(8, times)): out("\vasl a") if lvalue.name == "A": - shifts_A(rvalue.value) + shifts_A(rvalue) elif lvalue.name == "X": with preserving_registers({'A'}, scope, out): out("\vtxa") - shifts_A(rvalue.value) + shifts_A(rvalue) out("\vtax") elif lvalue.name == "Y": with preserving_registers({'A'}, scope, out): out("\vtya") - shifts_A(rvalue.value) + shifts_A(rvalue) out("\vtay") else: raise CodeError("unsupported register for aug assign", str(lvalue)) # @todo <<=.word diff --git a/il65/optimize.py b/il65/optimize.py index bccd8472d..3088874ab 100644 --- a/il65/optimize.py +++ b/il65/optimize.py @@ -58,11 +58,14 @@ class Optimizer: def constant_folding(self) -> None: for expression in self.module.all_nodes(Expression): + if isinstance(expression, LiteralValue): + continue try: - evaluated = process_expression(expression, expression.sourceref) # type: ignore + evaluated = process_expression(expression) # type: ignore if evaluated is not expression: # replace the node with the newly evaluated result expression.parent.replace_node(expression, evaluated) + self.optimizations_performed = True except ParseError: raise except Exception as x: @@ -148,12 +151,12 @@ class Optimizer: for assignment in self.module.all_nodes(Assignment): if len(assignment.left.nodes) > 1: continue - if not isinstance(assignment.right, Expression) or assignment.right.unary: + if not isinstance(assignment.right, ExpressionWithOperator) or assignment.right.unary: continue expr = assignment.right if expr.operator in ('-', '/', '//', '**', '<<', '>>', '&'): # non-associative operators if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): - num_val = expr.right.const_num_val() + num_val = expr.right.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) @@ -162,13 +165,13 @@ class Optimizer: if expr.operator not in ('+', '*', '|', '^'): # associative operators continue if isinstance(expr.right, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.left): - num_val = expr.right.const_num_val() + num_val = expr.right.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) self.optimizations_performed = True elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right): - num_val = expr.left.const_num_val() + num_val = expr.left.const_value() operator = expr.operator + '=' aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator) assignment.my_scope().replace_node(assignment, aug_assign) @@ -189,6 +192,7 @@ class Optimizer: print_warning("{}: removed superfluous assignment".format(prev_node.sourceref)) prev_node = node + @no_type_check def optimize_assignments(self) -> None: # remove assignment statements that do nothing (A=A) # remove augmented assignments that have no effect (x+=0, x-=0, x/=1, x//=1, x*=1) @@ -385,27 +389,27 @@ class Optimizer: node.my_scope().nodes.remove(node) -def process_expression(expr: Expression, sourceref: SourceRef) -> Any: +def process_expression(expr: Expression) -> Any: # process/simplify all expressions (constant folding etc) - if expr.must_be_constant: - return process_constant_expression(expr, sourceref) + if expr.is_compile_constant() or isinstance(expr, ExpressionWithOperator) and expr.must_be_constant: + return process_constant_expression(expr, expr.sourceref) else: - return process_dynamic_expression(expr, sourceref) + return process_dynamic_expression(expr, expr.sourceref) -def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue: +def process_constant_expression(expr: Expression, sourceref: SourceRef) -> LiteralValue: # the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue. - if isinstance(expr, (int, float, str, bool)): - raise TypeError("expr node should not be a python primitive value", expr, sourceref) - elif expr is None or isinstance(expr, LiteralValue): + if isinstance(expr, LiteralValue): return expr + if expr.is_compile_constant(): + return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore elif isinstance(expr, SymbolName): value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref) if isinstance(value, VarDef): if value.vartype == VarType.MEMORY: raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref) value = value.value - if isinstance(value, Expression): + if isinstance(value, ExpressionWithOperator): raise ExpressionEvaluationError("circular reference?", expr.sourceref) elif isinstance(value, LiteralValue): return value @@ -452,45 +456,46 @@ def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref) else: raise NotImplementedError("weird call target", expr.target) - elif not isinstance(expr, Expression): - raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) - if expr.unary: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_constant_expression(expr.left, left_sourceref) - if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): - try: - if expr.operator == '-': - return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore - elif expr.operator == '~': - return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore - elif expr.operator in ("++", "--"): - raise ValueError("incr/decr should not be an expression") - raise ValueError("invalid unary operator", expr.operator) - except TypeError as x: - raise ParseError(str(x), expr.sourceref) from None - raise ValueError("invalid operand type for unary operator", expr.left, expr.operator) - else: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_constant_expression(expr.left, left_sourceref) - right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref - expr.right = process_constant_expression(expr.right, right_sourceref) - if isinstance(expr.left, LiteralValue): - if isinstance(expr.right, LiteralValue): - return expr.evaluate_primitive_constants(expr.right.sourceref) - else: - raise ExpressionEvaluationError("constant literal value required on right, not {}" - .format(expr.right.__class__.__name__), right_sourceref) + elif isinstance(expr, ExpressionWithOperator): + if expr.unary: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_constant_expression(expr.left, left_sourceref) + if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float): + try: + if expr.operator == '-': + return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore + elif expr.operator == '~': + return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore + elif expr.operator in ("++", "--"): + raise ValueError("incr/decr should not be an expression") + raise ValueError("invalid unary operator", expr.operator) + except TypeError as x: + raise ParseError(str(x), expr.sourceref) from None + raise ValueError("invalid operand type for unary operator", expr.left, expr.operator) else: - raise ExpressionEvaluationError("constant literal value required on left, not {}" - .format(expr.left.__class__.__name__), left_sourceref) + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_constant_expression(expr.left, left_sourceref) + right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref + expr.right = process_constant_expression(expr.right, right_sourceref) + if isinstance(expr.left, LiteralValue): + if isinstance(expr.right, LiteralValue): + return expr.evaluate_primitive_constants(expr.right.sourceref) + else: + raise ExpressionEvaluationError("constant literal value required on right, not {}" + .format(expr.right.__class__.__name__), right_sourceref) + else: + raise ExpressionEvaluationError("constant literal value required on left, not {}" + .format(expr.left.__class__.__name__), left_sourceref) + else: + raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref) -def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any: +def process_dynamic_expression(expr: Expression, sourceref: SourceRef) -> Any: # constant-fold a dynamic expression - if isinstance(expr, (int, float, str, bool)): - raise TypeError("expr node should not be a python primitive value", expr, sourceref) - elif expr is None or isinstance(expr, LiteralValue): + if isinstance(expr, LiteralValue): return expr + if expr.is_compile_constant(): + return LiteralValue(value=expr.const_value(), sourceref=sourceref) # type: ignore elif isinstance(expr, SymbolName): try: return process_constant_expression(expr, sourceref) @@ -514,24 +519,25 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any: if isinstance(expr.operand, SymbolName): check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref) return expr - elif not isinstance(expr, Expression): - raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) - if expr.unary: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_dynamic_expression(expr.left, left_sourceref) - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr + elif isinstance(expr, ExpressionWithOperator): + if expr.unary: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_dynamic_expression(expr.left, left_sourceref) + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr + else: + left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref + expr.left = process_dynamic_expression(expr.left, left_sourceref) + right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref + expr.right = process_dynamic_expression(expr.right, right_sourceref) + try: + return process_constant_expression(expr, sourceref) + except ExpressionEvaluationError: + return expr else: - left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref - expr.left = process_dynamic_expression(expr.left, left_sourceref) - right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref - expr.right = process_dynamic_expression(expr.right, right_sourceref) - try: - return process_constant_expression(expr, sourceref) - except ExpressionEvaluationError: - return expr + raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref) def optimize(mod: Module) -> None: diff --git a/il65/plyparse.py b/il65/plyparse.py index e5419a07b..f4437428d 100644 --- a/il65/plyparse.py +++ b/il65/plyparse.py @@ -77,10 +77,11 @@ class AstNode: def all_nodes(self, *nodetypes: type) -> Generator['AstNode', None, None]: nodetypes = nodetypes or (AstNode, ) - for node in list(self.nodes): - if isinstance(node, nodetypes): # type: ignore + child_nodes = list(self.nodes) + for node in child_nodes: + if isinstance(node, nodetypes): yield node - for node in self.nodes: + for node in child_nodes: if isinstance(node, AstNode): yield from node.all_nodes(*nodetypes) @@ -293,8 +294,19 @@ class Label(AstNode): # no subnodes. +@attr.s(cmp=False, slots=True, repr=False) +class Expression(AstNode): + # just a common base class for the nodes that are an expression themselves: + # ExpressionWithOperator, AddressOf, LiteralValue, SymbolName, Register, SubCall, Dereference + def is_compile_constant(self) -> bool: + raise NotImplementedError("implement in subclass") + + def const_value(self) -> Union[int, float, bool, str]: + raise NotImplementedError("implement in subclass") + + @attr.s(cmp=False, slots=True) -class Register(AstNode): +class Register(Expression): name = attr.ib(type=str, validator=attr.validators.in_(REGISTER_SYMBOLS)) datatype = attr.ib(type=DataType, init=False) # no subnodes. @@ -320,6 +332,12 @@ class Register(AstNode): return NotImplemented return self.name < other.name + def is_compile_constant(self) -> bool: + return False + + def const_value(self) -> Union[int, float, bool, str]: + raise TypeError("register doesn't have a constant numeric value", self) + @attr.s(cmp=False) class PreserveRegs(AstNode): @@ -386,51 +404,57 @@ class Subroutine(AstNode): @attr.s(cmp=True, slots=True, repr=False) -class LiteralValue(AstNode): +class LiteralValue(Expression): # no subnodes. value = attr.ib() def __repr__(self) -> str: return "".format(self.value, self.sourceref) - def const_num_val(self) -> Union[int, float]: - if isinstance(self.value, (int, float)): - return self.value - raise TypeError("literal value is not numeric", self) + def const_value(self) -> Union[int, float, bool, str]: + return self.value + + def is_compile_constant(self) -> bool: + return True @attr.s(cmp=False) -class AddressOf(AstNode): +class AddressOf(Expression): # no subnodes. name = attr.ib(type=str) - def const_num_val(self) -> Union[int, float]: + def is_compile_constant(self) -> bool: + return False + + def const_value(self) -> Union[int, float, bool, str]: symdef = self.my_scope().lookup(self.name) if isinstance(symdef, VarDef): if symdef.zp_address is not None: return symdef.zp_address if symdef.vartype == VarType.MEMORY: - return symdef.value.const_num_val() + return symdef.value.const_value() raise TypeError("can only take constant address of a memory mapped variable", self) raise TypeError("should be a vardef to be able to take its address", self) @attr.s(cmp=False, slots=True) -class SymbolName(AstNode): +class SymbolName(Expression): # no subnodes. name = attr.ib(type=str) - def const_num_val(self) -> Union[int, float]: + def is_compile_constant(self) -> bool: + symdef = self.my_scope().lookup(self.name) + return isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST + + def const_value(self) -> Union[int, float, bool, str]: symdef = self.my_scope().lookup(self.name) if isinstance(symdef, VarDef) and symdef.vartype == VarType.CONST: - if symdef.datatype.isnumeric(): - return symdef.const_num_val() - raise TypeError("not a constant value", self) - raise TypeError("should be a vardef to be able to take its constant numeric value", self) + return symdef.const_value() + raise TypeError("should be a const vardef to be able to take its constant numeric value", self) @attr.s(cmp=False) -class Dereference(AstNode): +class Dereference(Expression): # one subnode: operand (SymbolName, int or register name) datatype = attr.ib() size = attr.ib(type=int, default=None) @@ -452,6 +476,12 @@ class Dereference(AstNode): raise ParseError("dereference target value must be byte, word, float", self.datatype.sourceref) self.datatype = self.datatype.to_enum() + def is_compile_constant(self) -> bool: + return False + + def const_value(self) -> Union[int, float, bool, str]: + raise TypeError("dereference is not a constant numeric value") + @attr.s(cmp=False) class IncrDecr(AstNode): @@ -483,12 +513,12 @@ class IncrDecr(AstNode): @attr.s(cmp=False, slots=True, repr=False) -class Expression(AstNode): - left = attr.ib() +class ExpressionWithOperator(Expression): + left = attr.ib() # type: Expression operator = attr.ib(type=str) - right = attr.ib() + right = attr.ib() # type: Expression unary = attr.ib(type=bool, default=False) - # when evaluating an expression, does it have to be a constant value? + # when evaluating the expression, does it have to be a compile-time constant value? must_be_constant = attr.ib(type=bool, init=False, default=False) def __attrs_post_init__(self): @@ -496,9 +526,12 @@ class Expression(AstNode): if self.operator == "mod": self.operator = "%" # change it back to the more common '%' - def const_num_val(self) -> Union[int, float]: + def const_value(self) -> Union[int, float, bool, str]: raise TypeError("an expression is not a constant", self) + def is_compile_constant(self) -> bool: + return False + def evaluate_primitive_constants(self, sourceref: SourceRef) -> LiteralValue: # make sure the lvalue and rvalue are primitives, and the operator is allowed assert isinstance(self.left, LiteralValue) @@ -513,19 +546,6 @@ class Expression(AstNode): except Exception as x: raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None - def print_tree(self) -> None: - def tree(expr: Any, level: int) -> str: - indent = " "*level - if not isinstance(expr, Expression): - return indent + str(expr) + "\n" - if expr.unary: - return indent + "{}{}".format(expr.operator, tree(expr.left, level+1)) - else: - return indent + "{}".format(tree(expr.left, level+1)) + \ - indent + str(expr.operator) + "\n" + \ - indent + "{}".format(tree(expr.right, level + 1)) - print(tree(self, 0)) - @attr.s(cmp=False, repr=False) class Goto(AstNode): @@ -537,7 +557,7 @@ class Goto(AstNode): return self.nodes[0] # type: ignore @property - def condition(self) -> Expression: + def condition(self) -> Optional[Expression]: return self.nodes[1] if len(self.nodes) == 2 else None # type: ignore @@ -558,7 +578,7 @@ class CallArguments(AstNode): @attr.s(cmp=False, repr=False) -class SubCall(AstNode): +class SubCall(Expression): # has three subnodes: # 0: target (Symbolname, int, or Dereference), # 1: preserve_regs (PreserveRegs) @@ -576,10 +596,16 @@ class SubCall(AstNode): def arguments(self) -> CallArguments: return self.nodes[2] # type: ignore + def is_compile_constant(self) -> bool: + return False + + def const_value(self) -> Union[int, float, bool, str]: + raise TypeError("subroutine call is not a constant value", self) + @attr.s(cmp=False, slots=True, repr=False) class VarDef(AstNode): - # zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.). + # zero or one subnode: value (Expression). name = attr.ib(type=str) vartype = attr.ib() datatype = attr.ib() @@ -587,29 +613,26 @@ class VarDef(AstNode): zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later @property - def value(self) -> Union[LiteralValue, Expression, AddressOf, SymbolName]: + def value(self) -> Expression: return self.nodes[0] if self.nodes else None # type: ignore @value.setter - def value(self, value: Union[LiteralValue, Expression, AddressOf, SymbolName]) -> None: - assert isinstance(value, (LiteralValue, Expression, AddressOf, SymbolName)) + def value(self, value: Expression) -> None: + assert isinstance(value, Expression) if self.nodes: self.nodes[0] = value else: self.nodes.append(value) - # if the value is an expression, mark it as a *constant* expression here - if isinstance(value, Expression): + if isinstance(value, ExpressionWithOperator): + # an expression in a vardef should evaluate to a compile-time constant: value.must_be_constant = True - def const_num_val(self) -> Union[int, float]: + def const_value(self) -> Union[int, float, bool, str]: if self.vartype != VarType.CONST: raise TypeError("not a constant value", self) - if self.datatype.isnumeric(): - if self.nodes: - return self.nodes[0].const_num_val() # type: ignore - raise ValueError("no value", self) - else: - raise TypeError("not numeric", self) + if self.nodes and isinstance(self.nodes[0], Expression): + return self.nodes[0].const_value() + raise ValueError("no value", self) def __attrs_post_init__(self): # convert vartype to enum @@ -647,15 +670,15 @@ class VarDef(AstNode): class Return(AstNode): # one, two or three subnodes: value_A, value_X, value_Y (all three Expression) @property - def value_A(self) -> Expression: + def value_A(self) -> Optional[Expression]: return self.nodes[0] if self.nodes else None # type: ignore @property - def value_X(self) -> Expression: + def value_X(self) -> Optional[Expression]: return self.nodes[0] if self.nodes else None # type: ignore @property - def value_Y(self) -> Expression: + def value_Y(self) -> Optional[Expression]: return self.nodes[0] if self.nodes else None # type: ignore @@ -709,20 +732,20 @@ class AssignmentTargets(AstNode): @attr.s(cmp=False, slots=True, repr=False) class Assignment(AstNode): # can be single- or multi-assignment - # has two subnodes: left (=AssignmentTargets) and right (=reg/literal/expr - # or another Assignment but those will be converted to multi assign) + # has two subnodes: left (=AssignmentTargets) and right (=Expression, + # or another Assignment but those will be converted into multi assign) @property def left(self) -> AssignmentTargets: return self.nodes[0] # type: ignore @property - def right(self) -> Union[Register, LiteralValue, Expression]: + def right(self) -> Expression: return self.nodes[1] # type: ignore @right.setter - def right(self, rvalue: Union[Register, LiteralValue, Expression, Dereference, SymbolName, SubCall]) -> None: - assert isinstance(rvalue, (Register, LiteralValue, Expression, Dereference, SymbolName, SubCall)) + def right(self, rvalue: Expression) -> None: + assert isinstance(rvalue, Expression) self.nodes[1] = rvalue @@ -789,7 +812,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode, elif datatype in (DataType.BYTE, DataType.WORD, DataType.FLOAT): if type(value.value) not in (int, float): raise TypeError("cannot assign '{:s}' to {:s}".format(type(value.value).__name__, datatype.name.lower()), sourceref) - elif isinstance(value, (Expression, SubCall)): + elif isinstance(value, (ExpressionWithOperator, SubCall)): return False, value elif isinstance(value, SymbolName): symboldef = value.my_scope().lookup(value.name) @@ -797,7 +820,7 @@ def coerce_constant_value(datatype: DataType, value: AstNode, return True, symboldef.value elif isinstance(value, AddressOf): try: - address = value.const_num_val() + address = value.const_value() return True, LiteralValue(value=address, sourceref=value.sourceref) # type: ignore except TypeError: return False, value @@ -1357,14 +1380,14 @@ def p_expression(p): | expression EQUALS expression | expression NOTEQUALS expression """ - p[0] = Expression(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2)) + p[0] = ExpressionWithOperator(left=p[1], operator=p[2], right=p[3], sourceref=_token_sref(p, 2)) def p_expression_uminus(p): """ expression : '-' expression %prec UNARY_MINUS """ - p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) + p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) def p_expression_addressof(p): @@ -1378,14 +1401,14 @@ def p_unary_expression_bitinvert(p): """ expression : BITINVERT expression """ - p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) + p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) def p_unary_expression_logicnot(p): """ expression : LOGICNOT expression """ - p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) + p[0] = ExpressionWithOperator(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1)) def p_expression_group(p): diff --git a/tests/test_parser.py b/tests/test_parser.py index 4b916bfe7..ca54b0e21 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,6 +1,7 @@ +import pytest from il65.plylex import lexer, tokens, find_tok_column, literals, reserved, SourceRef -from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, Return, Scope, \ - VarDef, Expression, LiteralValue, Label, SubCall, Dereference +from il65.plyparse import parser, connect_parents, TokenFilter, Module, Subroutine, Block, IncrDecr, Scope, \ + VarDef, Expression, ExpressionWithOperator, LiteralValue, Label, SubCall, Dereference from il65.datatypes import DataType @@ -127,7 +128,7 @@ def test_parser(): assert block.name == "block" bool_vdef = block.scope.nodes[1] assert isinstance(bool_vdef, VarDef) - assert isinstance(bool_vdef.value, Expression) + assert isinstance(bool_vdef.value, ExpressionWithOperator) assert isinstance(bool_vdef.value.right, LiteralValue) assert isinstance(bool_vdef.value.right.value, int) assert bool_vdef.value.right.value == 1 @@ -283,3 +284,11 @@ def test_boolean_int(): assert type(var2.value.value) is int and var2.value.value == 0 assert type(assgn1.right.value) is int and assgn1.right.value == 1 assert type(assgn2.right.value) is int and assgn2.right.value == 0 + + +def test_incrdecr(): + sref = SourceRef("test", 1, 1) + with pytest.raises(ValueError): + IncrDecr(operator="??", sourceref=sref) + i = IncrDecr(operator="++", sourceref=sref) + assert i.howmuch == 1 diff --git a/tests/test_vardef.py b/tests/test_vardef.py index f31a3d5ac..512399219 100644 --- a/tests/test_vardef.py +++ b/tests/test_vardef.py @@ -1,6 +1,6 @@ import pytest from il65.datatypes import DataType -from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, Expression, Scope, AddressOf, SymbolName, UndefinedSymbolError +from il65.plyparse import LiteralValue, VarDef, VarType, DatatypeNode, ExpressionWithOperator, Scope, AddressOf, SymbolName, UndefinedSymbolError from il65.plylex import SourceRef # zero or one subnode: value (an Expression, LiteralValue, AddressOf or SymbolName.). @@ -61,14 +61,14 @@ def test_set_value(): assert v.value is None v.value = LiteralValue(value="hello", sourceref=sref) assert v.value.value == "hello" - e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) + e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) assert not e.must_be_constant v.value = e assert v.value is e assert e.must_be_constant -def test_const_num_val(): +def test_const_value(): sref = SourceRef("test", 1, 1) scope = Scope(nodes=[], level="block", sourceref=sref) vardef = VarDef(name="constvar", vartype="const", datatype=None, sourceref=sref) @@ -82,37 +82,37 @@ def test_const_num_val(): scope.add_node(vardef) v = VarDef(name="v1", vartype="var", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) with pytest.raises(TypeError): - v.const_num_val() + v.const_value() v = VarDef(name="v1", vartype="memory", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) with pytest.raises(TypeError): - v.const_num_val() + v.const_value() v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="word", sourceref=sref), sourceref=sref) - assert v.const_num_val() == 0 + assert v.const_value() == 0 v.value = LiteralValue(value=42, sourceref=sref) - assert v.const_num_val() == 42 + assert v.const_value() == 42 v = VarDef(name="v1", vartype="const", datatype=DatatypeNode(name="float", sourceref=sref), sourceref=sref) - assert v.const_num_val() == 0 + assert v.const_value() == 0 v.value = LiteralValue(value=42.9988, sourceref=sref) - assert v.const_num_val() == 42.9988 - e = Expression(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) + assert v.const_value() == 42.9988 + e = ExpressionWithOperator(left=LiteralValue(value=42, sourceref=sref), operator="-", unary=True, right=None, sourceref=sref) v.value = e with pytest.raises(TypeError): - v.const_num_val() + v.const_value() s = SymbolName(name="unexisting", sourceref=sref) s.parent = scope v.value = s with pytest.raises(UndefinedSymbolError): - v.const_num_val() + v.const_value() s = SymbolName(name="constvar", sourceref=sref) s.parent = scope v.value = s - assert v.const_num_val() == 43 + assert v.const_value() == 43 a = AddressOf(name="varvar", sourceref=sref) a.parent = scope v.value = a with pytest.raises(TypeError): - v.const_num_val() + v.const_value() a = AddressOf(name="memvar", sourceref=sref) a.parent = scope v.value = a - assert v.const_num_val() == 45 + assert v.const_value() == 45 diff --git a/todo.ill b/todo.ill index 78f04b1ab..01e152d33 100644 --- a/todo.ill +++ b/todo.ill @@ -5,16 +5,15 @@ start: - Y+=10 Y-=5 Y-=8 Y-- - flt+=2 - flt+=2 - flt+=2 - flt+=2 - flt+=2 + ;flt+=2 ; @todo implement on float + ;flt+=2 + ;flt+=2 + ;flt+=2 + ;flt+=2 X=0 X+=5 @@ -27,7 +26,8 @@ start: X=24|X X=X^66 X+=250 - X+=5 + X=5+2+3+4 + X+=5+2+3+4 X-=100 X-=50 X-=5 @@ -52,6 +52,7 @@ start: ;flt += 10.1 ;flt += 100.1 ;flt += 1000.1 + ;flt *= 2.34 return 44 }