expressions

This commit is contained in:
Irmen de Jong 2018-01-18 23:33:02 +01:00
parent 861379c4d7
commit 67f1941766
3 changed files with 94 additions and 113 deletions

View File

@ -14,7 +14,7 @@ import attr
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \
check_symbol_definition, UndefinedSymbolError
check_symbol_definition, UndefinedSymbolError, process_expression
from .plylex import SourceRef, print_bold
from .datatypes import DataType, VarType
@ -38,7 +38,7 @@ class PlyParser:
self.check_all_symbolnames(module)
self.create_multiassigns(module)
self.check_and_merge_zeropages(module)
self.process_all_expressions_and_symbolnames(module)
self.process_all_expressions(module)
return module # XXX
# if not self.parsing_import:
# # these shall only be done on the main module after all imports have been done:
@ -152,7 +152,7 @@ class PlyParser:
for node in module.all_nodes([SymbolName]):
check_symbol_definition(node.name, node.my_scope(), node.sourceref)
def process_all_expressions_and_symbolnames(self, module: Module) -> None:
def process_all_expressions(self, module: Module) -> None:
# process/simplify all expressions (constant folding etc)
encountered_blocks = set()
for node in module.all_nodes():
@ -163,13 +163,12 @@ class PlyParser:
raise ValueError("block names not unique:", blockname)
encountered_blocks.add(blockname)
elif isinstance(node, Expression):
print("EXPRESSION", node) # XXX
# try:
# node.process_expressions(block.scope)
# except ParseError:
# raise
# except Exception as x:
# self.handle_internal_error(x, "process_expressions of node {} in block {}".format(node, block.name))
try:
process_expression(node, node.my_scope(), node.sourceref)
except ParseError:
raise
except Exception as x:
self.handle_internal_error(x, "process_expressions of node {}".format(node))
elif isinstance(node, IncrDecr) and node.howmuch not in (0, 1):
_, node.howmuch = coerce_constant_value(datatype_of(node.target, node.my_scope()), node.howmuch, node.sourceref)
elif isinstance(node, Assignment):

View File

@ -59,10 +59,13 @@ tokens = (
"BITOR",
"BITXOR",
"BITINVERT",
"SHIFTLEFT",
"SHIFTRIGHT",
"LOGICAND",
"LOGICOR",
"LOGICNOT",
"INTEGERDIVIDE",
"MODULO",
"POWER",
"LABEL",
"IF",
@ -75,6 +78,8 @@ literals = ['+', '-', '*', '/', '(', ')', '[', ']', '{', '}', '.', ',', '!', '?'
# regex rules for simple tokens
t_SHIFTLEFT = r"<<"
t_SHIFTRIGHT = r">>"
t_INTEGERDIVIDE = r"//"
t_BITAND = r"&"
t_BITOR = r"\|"
@ -118,6 +123,7 @@ reserved = {
"not": "LOGICNOT",
"and": "LOGICAND",
"or": "LOGICOR",
"mod": "MODULO",
"AX": "REGISTER",
"AY": "REGISTER",
"XY": "REGISTER",

View File

@ -61,8 +61,6 @@ class AstNode:
sourceref = attr.ib(type=SourceRef)
parent = attr.ib(init=False, default=None) # will be hooked up later
nodes = attr.ib(type=list, init=False, default=attr.Factory(list)) # type: List['AstNode']
# when evaluating an expression, does it have to be a constant value?:
processed_expr_must_be_constant = attr.ib(type=bool, init=False, default=False)
@property
def lineref(self) -> str:
@ -70,7 +68,7 @@ class AstNode:
def my_scope(self) -> 'Scope':
# returns the closest Scope in the ancestry of this node, or raises LookupError if no scope is found
scope = self
scope = self.parent
while scope:
if isinstance(scope, Scope):
return scope
@ -104,11 +102,6 @@ class AstNode:
else:
self.nodes.insert(index, newnode)
def process_expressions(self, scope: 'Scope') -> None: # XXX remove, use all_nodes
# process/simplify all expressions (constant folding etc)
# this is implemented in node types that have expression(s) and that should act on this.
pass
@attr.s(cmp=False)
class Directive(AstNode):
@ -125,7 +118,17 @@ class Scope(AstNode):
symbols = attr.ib(init=False)
name = attr.ib(init=False) # will be set by enclosing block, or subroutine etc.
parent_scope = attr.ib(init=False, default=None) # will be wired up later
save_registers = attr.ib(type=bool, default=None, init=False) # None = look in parent scope's setting @todo property that does that
_save_registers = attr.ib(type=bool, default=None, init=False)
@property
def save_registers(self) -> bool:
    """Effective register-saving setting for this scope.

    A value stored on this scope wins.  While nothing has been stored
    (``_save_registers`` is still None) the question is passed on to the
    scope returned by ``self.my_scope()``; whatever that call raises
    propagates to the caller.
    """
    local_setting = self._save_registers
    if local_setting is None:
        # nothing configured here: defer to the scope that my_scope() yields
        return self.my_scope().save_registers
    return local_setting

@save_registers.setter
def save_registers(self, save: bool) -> None:
    """Store an explicit setting on this scope, overriding any inherited one."""
    self._save_registers = save
def __attrs_post_init__(self):
# populate the symbol table for this scope for fast lookups via scope.lookup("name") or scope.lookup("dotted.name")
@ -238,6 +241,7 @@ class Block(AstNode):
@scope.setter
def scope(self, scope: Scope) -> None:
assert isinstance(scope, Scope)
self.nodes.clear()
self.nodes.append(scope)
scope.name = self.name
@ -373,6 +377,7 @@ class Subroutine(AstNode):
@scope.setter
def scope(self, scope: Scope) -> None:
assert isinstance(scope, Scope)
self.nodes.clear()
self.nodes.append(scope)
scope.name = self.name
@ -385,10 +390,6 @@ class Goto(AstNode):
# one or two subnodes: target (SymbolName, int or Dereference) and optionally: condition (Expression)
if_stmt = attr.ib(default=None)
def process_expressions(self, scope: Scope) -> None:
if len(self.nodes) == 2:
self.nodes[1] = process_expression(self.nodes[1], scope, self.nodes[1].sourceref)
@attr.s(cmp=True, slots=True)
class LiteralValue(AstNode):
@ -450,6 +451,7 @@ class IncrDecr(AstNode):
raise ParseError("cannot incr/decr that register", self.sourceref)
if isinstance(target, TargetRegisters):
raise ParseError("cannot incr/decr multiple registers at once", self.sourceref)
assert isinstance(target, (Register, SymbolName, Dereference))
self.nodes.clear()
self.nodes.append(target)
@ -466,22 +468,23 @@ class Expression(AstNode):
operator = attr.ib(type=str)
right = attr.ib()
unary = attr.ib(type=bool, default=False)
# when evaluating an expression, does it have to be a constant value?
must_be_constant = attr.ib(type=bool, init=False, default=False)
def __attrs_post_init__(self):
    """Check the operator and normalise the 'mod' spelling after construction."""
    # incr/decr are statements of their own and must never arrive here as an operator
    assert self.operator not in ("++", "--"), "incr/decr should not be an expression"
    if self.operator != "mod":
        return
    # change it back to the more common '%'
    self.operator = "%"
def process_expressions(self, scope: Scope) -> None:
raise RuntimeError("must be done via parent node's process_expressions")
def evaluate_primitive_constants(self, scope: Scope) -> LiteralValue:
def evaluate_primitive_constants(self, scope: Scope, sourceref: SourceRef) -> LiteralValue:
# make sure the lvalue and rvalue are primitives, and the operator is allowed
assert isinstance(self.left, LiteralValue)
assert isinstance(self.right, LiteralValue)
if self.operator not in {'+', '-', '*', '/', '//', '~', '<', '>', '<=', '>=', '==', '!='}:
raise ValueError("operator", self)
if self.operator not in {'+', '-', '*', '/', '//', '~', '|', '&', '%', '<<', '>>', '<', '>', '<=', '>=', '==', '!='}:
raise ValueError("operator", self.operator)
estr = "{} {} {}".format(repr(self.left.value), self.operator, repr(self.right.value))
try:
return eval(estr, {}, {}) # safe because of checks above
return LiteralValue(value=eval(estr, {}, {}), sourceref=sourceref) # type: ignore # safe because of checks above
except Exception as x:
raise ExpressionEvaluationError("expression error: " + str(x), self.sourceref) from None
@ -508,9 +511,6 @@ class CallArgument(AstNode):
def value(self) -> Expression:
return self.nodes[0] # type: ignore
def process_expressions(self, scope: Scope) -> None:
self.nodes[0] = process_expression(self.nodes[0], scope, self.sourceref)
@attr.s(cmp=False)
class CallArguments(AstNode):
@ -537,11 +537,6 @@ class SubCall(AstNode):
def arguments(self) -> CallArguments:
return self.nodes[2] # type: ignore
def process_expressions(self, scope: Scope) -> None:
for callarg in self.nodes[2].nodes:
assert isinstance(callarg, CallArgument)
callarg.process_expressions(scope)
@attr.s(cmp=False, slots=True, repr=False)
class VarDef(AstNode):
@ -553,18 +548,19 @@ class VarDef(AstNode):
zp_address = attr.ib(type=int, default=None, init=False) # the address in the zero page if this var is there, will be set later
@property
def value(self) -> Expression:
def value(self) -> Union[LiteralValue, Expression]:
return self.nodes[0] if self.nodes else None # type: ignore
@value.setter
def value(self, newvalue: Expression) -> None:
def value(self, value: Union[LiteralValue, Expression]) -> None:
assert isinstance(value, (LiteralValue, Expression))
if self.nodes:
self.nodes[0] = newvalue
self.nodes[0] = value
else:
self.nodes.append(newvalue)
self.nodes.append(value)
# if the value is an expression, mark it as a *constant* expression here
if isinstance(self.nodes[0], AstNode): # XXX expression only?
self.value.processed_expr_must_be_constant = True
if isinstance(value, Expression):
value.must_be_constant = True
def __attrs_post_init__(self):
# convert vartype to enum
@ -588,24 +584,13 @@ class VarDef(AstNode):
if self.datatype.isarray() and sum(self.size) in (0, 1):
print("warning: {}: array/matrix with size 1, use normal byte/word instead for efficiency".format(self.sourceref))
if self.value is None and (self.datatype.isnumeric() or self.datatype.isarray()):
self.value = 0
self.value = LiteralValue(value=0, sourceref=self.sourceref)
# if it's a matrix with interleave, it must be memory mapped
if self.datatype == DataType.MATRIX and len(self.size) == 3:
if self.vartype != VarType.MEMORY:
raise ParseError("matrix with interleave can only be a memory-mapped variable", self.sourceref)
# note: value coercion is done later, when all expressions are evaluated
def process_expressions(self, scope: Scope) -> None:
self.value = process_expression(self.value, scope, self.sourceref)
assert not isinstance(self.value, Expression), "processed expression for vardef should reduce to a constant value"
if self.vartype in (VarType.CONST, VarType.VAR):
try:
_, self.value = coerce_constant_value(self.datatype, self.value, self.sourceref)
except OverflowError as x:
raise ParseError(str(x), self.sourceref) from None
except TypeError as x:
raise ParseError("processed expression vor vardef is not a constant value: " + str(x), self.sourceref) from None
@attr.s(cmp=False, repr=False)
class Return(AstNode):
@ -622,29 +607,6 @@ class Return(AstNode):
def value_Y(self) -> Expression:
return self.nodes[0] # type: ignore
def process_expressions(self, scope: Scope) -> None:
if self.nodes[0] is not None:
self.nodes[0] = process_expression(self.nodes[0], scope, self.sourceref)
if isinstance(self.nodes[0], (int, float, str, bool)):
try:
_, self.nodes[0] = coerce_constant_value(DataType.BYTE, self.nodes[0], self.sourceref)
except (OverflowError, TypeError) as x:
raise ParseError("first value (A): " + str(x), self.sourceref) from None
if self.nodes[1] is not None:
self.nodes[1] = process_expression(self.nodes[1], scope, self.sourceref)
if isinstance(self.nodes[1], (int, float, str, bool)):
try:
_, self.nodes[1] = coerce_constant_value(DataType.BYTE, self.nodes[1], self.sourceref)
except (OverflowError, TypeError) as x:
raise ParseError("second value (X): " + str(x), self.sourceref) from None
if self.nodes[2] is not None:
self.nodes[2] = process_expression(self.nodes[2], scope, self.sourceref)
if isinstance(self.nodes[2], (int, float, str, bool)):
try:
_, self.nodes[2] = coerce_constant_value(DataType.BYTE, self.nodes[2], self.sourceref)
except (OverflowError, TypeError) as x:
raise ParseError("third value (Y): " + str(x), self.sourceref) from None
@attr.s(cmp=False, slots=True, repr=False)
class AssignmentTargets(AstNode):
@ -667,11 +629,9 @@ class Assignment(AstNode):
@right.setter
def right(self, rvalue: Union[LiteralValue, Expression]) -> None:
assert isinstance(rvalue, (LiteralValue, Expression))
self.nodes[1] = rvalue
def process_expressions(self, scope: Scope) -> None:
self.nodes[1] = process_expression(self.nodes[1], scope, self.nodes[1].sourceref)
@attr.s(cmp=False, slots=True, repr=False)
class AugAssignment(AstNode):
@ -686,9 +646,6 @@ class AugAssignment(AstNode):
def right(self) -> Expression:
return self.nodes[1] # type: ignore
def process_expressions(self, scope: Scope) -> None:
self.nodes[1] = process_expression(self.nodes[1], scope, self.right.sourceref)
def datatype_of(assignmenttarget: AstNode, scope: Scope) -> DataType:
# tries to determine the DataType of an assignment target node
@ -749,26 +706,20 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
return False, value
def process_expression(value: Any, scope: Scope, sourceref: SourceRef) -> Any:
def process_expression(expr: Expression, scope: Scope, sourceref: SourceRef) -> Any:
# process/simplify all expressions (constant folding etc)
if isinstance(value, AstNode):
must_be_constant = value.processed_expr_must_be_constant
if expr.must_be_constant:
return process_constant_expression(expr, sourceref, scope)
else:
must_be_constant = False
if must_be_constant:
return process_constant_expression(value, sourceref, scope)
else:
return process_dynamic_expression(value, sourceref, scope)
return process_dynamic_expression(expr, sourceref, scope)
def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever)
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr)
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, LiteralValue):
return expr.value
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, symbolscope, expr.sourceref)
if isinstance(value, VarDef):
@ -809,7 +760,7 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return func(*func_args)
return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
@ -844,7 +795,7 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
expr.right = process_constant_expression(expr.right, right_sourceref, symbolscope)
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(symbolscope)
return expr.evaluate_primitive_constants(symbolscope, expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
@ -853,19 +804,12 @@ def process_constant_expression(expr: Any, sourceref: SourceRef, symbolscope: Sc
.format(expr.left.__class__.__name__), left_sourceref)
def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any:
try:
return scope.lookup(name)
except UndefinedSymbolError as x:
raise ParseError(str(x), sref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Scope) -> Any:
# constant-fold a dynamic expression
if expr is None or isinstance(expr, (int, float, str, bool)):
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, LiteralValue):
return expr.value
elif isinstance(expr, SymbolName):
try:
return process_constant_expression(expr, sourceref, symbolscope)
@ -909,6 +853,13 @@ def process_dynamic_expression(expr: Any, sourceref: SourceRef, symbolscope: Sco
return expr
def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any:
    """Look up *name* in *scope*, reporting an unknown symbol as a ParseError.

    Returns whatever ``scope.lookup(name)`` returns.

    Raises:
        ParseError: when the lookup raises UndefinedSymbolError; the new error
            carries the same message text together with *sref*.
    """
    try:
        return scope.lookup(name)
    except UndefinedSymbolError as x:
        # 'from None' matches the other ParseError conversions in this module:
        # the message already carries the lookup failure, so the internal
        # traceback is kept out of the reported error.
        raise ParseError(str(x), sref) from None
# ----------------- PLY parser definition follows ----------------------
def p_start(p):
@ -1401,14 +1352,23 @@ def p_aug_assignment(p):
"""
p[0] = AugAssignment(operator=p[2], sourceref=_token_sref(p, 2))
p[0].nodes.append(p[1])
p[0].nodes.append(p[2])
p[0].nodes.append(p[3])
precedence = (
    # Operator precedence table read by the PLY parser generator.
    # Rows run from lowest to highest precedence; every terminal is listed exactly once.
    # following the python operator precedence rules mostly; https://docs.python.org/3/reference/expressions.html#operator-precedence
    ('left', 'LOGICOR'),
    ('left', 'LOGICAND'),
    ('right', 'LOGICNOT'),
    ('left', "LT", "GT", "LE", "GE", "EQUALS", "NOTEQUALS"),
    ('left', 'BITOR'),
    ('left', 'BITXOR'),
    ('left', 'BITAND'),
    ('left', 'SHIFTLEFT', 'SHIFTRIGHT'),
    ('left', '+', '-'),
    ('left', '*', '/', 'INTEGERDIVIDE', 'MODULO'),
    ('right', 'UNARY_MINUS', 'BITINVERT', "UNARY_ADDRESSOF"),
    # NOTE(review): Python's own power operator is right-associative; confirm that 'left' is intended here
    ('left', 'POWER'),
    ('nonassoc', "COMMENT"),
)
@ -1419,6 +1379,15 @@ def p_expression(p):
| expression '-' expression
| expression '*' expression
| expression '/' expression
| expression MODULO expression
| expression BITOR expression
| expression BITXOR expression
| expression BITAND expression
| expression SHIFTLEFT expression
| expression SHIFTRIGHT expression
| expression LOGICOR expression
| expression LOGICAND expression
| expression POWER expression
| expression INTEGERDIVIDE expression
| expression LT expression
| expression GT expression
@ -1451,6 +1420,13 @@ def p_unary_expression_bitinvert(p):
p[0] = Expression(left=p[2], operator=p[1], right=None, unary=True, sourceref=_token_sref(p, 1))
def p_unary_expression_logicnot(p):
    """
    expression : LOGICNOT expression
    """
    # The docstring above is the grammar rule itself, so it must stay free of prose.
    # p[1] is the LOGICNOT token value, p[2] the operand expression it applies to.
    operator_token = p[1]
    operand = p[2]
    # a unary node keeps its single operand in 'left' and leaves 'right' empty
    p[0] = Expression(left=operand, operator=operator_token, right=None, unary=True, sourceref=_token_sref(p, 1))
def p_expression_group(p):
"""
expression : '(' expression ')'