prog8/il65/astparse.py
2017-12-28 22:27:13 +01:00

289 lines
11 KiB
Python

"""
Programming Language for 6502/6510 microprocessors
This is the expression parser/evaluator.
Written by Irmen de Jong (irmen@razorvine.net)
License: GNU GPL 3.0, see LICENSE
"""
import ast
from typing import Union, Optional, List, Tuple, Any
from .symbols import FLOAT_MAX_POSITIVE, FLOAT_MAX_NEGATIVE, SourceRef, SymbolTable, SymbolError, PrimitiveType
class ParseError(Exception):
def __init__(self, message: str, sourcetext: str, sourceref: SourceRef) -> None:
self.sourceref = sourceref
self.msg = message
self.sourcetext = sourcetext
def __str__(self):
return "{} {:s}".format(self.sourceref, self.msg)
class SourceLine:
def __init__(self, text: str, sourceref: SourceRef) -> None:
self.sourceref = sourceref
self.text = text.strip()
def to_error(self, message: str) -> ParseError:
return ParseError(message, self.text, self.sourceref)
def preprocess(self) -> str:
# transforms the source text into valid Python syntax by bending some things, so ast can parse it.
# $d020 -> 0xd020
# %101001 -> 0xb101001
# #something -> __ptr@something (matmult operator)
text = ""
quotes_stack = ""
characters = enumerate(self.text + " ")
for i, c in characters:
if c in ("'", '"'):
if quotes_stack and quotes_stack[-1] == c:
quotes_stack = quotes_stack[:-1]
else:
quotes_stack += c
text += c
continue
if not quotes_stack:
if c == '%' and self.text[i + 1] in "01":
text += "0b"
continue
if c == '$' and self.text[i + 1] in "0123456789abcdefABCDEF":
text += "0x"
continue
if c == '#':
if i > 0:
text += " "
text += "__ptr@"
continue
text += c
return text
def parse_arguments(text: str, sourceref: SourceRef) -> List[Tuple[str, PrimitiveType]]:
src = SourceLine(text, sourceref)
text = src.preprocess()
try:
nodes = ast.parse("__func({:s})".format(text), sourceref.file, "eval")
except SyntaxError as x:
raise src.to_error(str(x))
args = [] # type: List[Tuple[str, Any]]
if isinstance(nodes, ast.Expression):
for arg in nodes.body.args:
reprvalue = astnode_to_repr(arg)
args.append((None, reprvalue))
for kwarg in nodes.body.keywords:
reprvalue = astnode_to_repr(kwarg.value)
args.append((kwarg.arg, reprvalue))
return args
else:
raise TypeError("ast.Expression expected")
def parse_expr_as_comparison(text: str, sourceref: SourceRef) -> Tuple[str, str, str]:
src = SourceLine(text, sourceref)
text = src.preprocess()
try:
node = ast.parse(text, sourceref.file, mode="eval")
except SyntaxError as x:
raise src.to_error(str(x))
if not isinstance(node, ast.Expression):
raise TypeError("ast.Expression expected")
if isinstance(node.body, ast.Compare):
if len(node.body.ops) != 1:
raise src.to_error("only one comparison operator at a time is supported")
operator = {
"Eq": "==",
"NotEq": "!=",
"Lt": "<",
"LtE": "<=",
"Gt": ">",
"GtE": ">=",
"Is": None,
"IsNot": None,
"In": None,
"NotIn": None
}[node.body.ops[0].__class__.__name__]
if not operator:
raise src.to_error("unsupported comparison operator")
left = text[node.body.left.col_offset:node.body.comparators[0].col_offset].rstrip()[:-len(operator)]
right = text[node.body.comparators[0].col_offset:]
return left.strip(), operator, right.strip()
left = astnode_to_repr(node.body)
return left, "", ""
def parse_expr_as_int(text: str, context: Optional[SymbolTable], ppcontext: Optional[SymbolTable], sourceref: SourceRef, *,
minimum: int=0, maximum: int=0xffff) -> int:
result = parse_expr_as_primitive(text, context, ppcontext, sourceref, minimum=minimum, maximum=maximum)
if isinstance(result, int):
return result
src = SourceLine(text, sourceref)
raise src.to_error("int expected, not " + type(result).__name__)
def parse_expr_as_number(text: str, context: Optional[SymbolTable], ppcontext: Optional[SymbolTable], sourceref: SourceRef, *,
minimum: float=FLOAT_MAX_NEGATIVE, maximum: float=FLOAT_MAX_POSITIVE) -> Union[int, float]:
result = parse_expr_as_primitive(text, context, ppcontext, sourceref, minimum=minimum, maximum=maximum)
if isinstance(result, (int, float)):
return result
src = SourceLine(text, sourceref)
raise src.to_error("int or float expected, not " + type(result).__name__)
def parse_expr_as_string(text: str, context: Optional[SymbolTable], ppcontext: Optional[SymbolTable], sourceref: SourceRef) -> str:
result = parse_expr_as_primitive(text, context, ppcontext, sourceref)
if isinstance(result, str):
return result
src = SourceLine(text, sourceref)
raise src.to_error("string expected, not " + type(result).__name__)
def parse_expr_as_primitive(text: str, context: Optional[SymbolTable], ppcontext: Optional[SymbolTable], sourceref: SourceRef, *,
minimum: float = FLOAT_MAX_NEGATIVE, maximum: float = FLOAT_MAX_POSITIVE) -> PrimitiveType:
src = SourceLine(text, sourceref)
text = src.preprocess()
try:
node = ast.parse(text, sourceref.file, mode="eval")
except SyntaxError as x:
raise src.to_error(str(x))
if isinstance(node, ast.Expression):
result = ExpressionTransformer(src, context, ppcontext).evaluate(node)
else:
raise TypeError("ast.Expression expected")
if isinstance(result, bool):
return int(result)
if isinstance(result, (int, float)):
if minimum <= result <= maximum:
return result
raise src.to_error("number too large")
if isinstance(result, str):
return result
raise src.to_error("int or float or string expected, not " + type(result).__name__)
class EvaluatingTransformer(ast.NodeTransformer):
def __init__(self, src: SourceLine, context: SymbolTable, ppcontext: SymbolTable) -> None:
super().__init__()
self.src = src
self.context = context
self.ppcontext = ppcontext
def error(self, message: str, column: int=0) -> ParseError:
if column:
ref = self.src.sourceref.copy()
ref.column = column
else:
ref = self.src.sourceref
return ParseError(message, self.src.text, ref)
def evaluate(self, node: ast.Expression) -> PrimitiveType:
node = self.visit(node)
code = compile(node, self.src.sourceref.file, mode="eval")
if self.context:
globals = None
locals = self.context.as_eval_dict(self.ppcontext)
else:
globals = {"__builtins__": {}}
locals = None
try:
result = eval(code, globals, locals) # XXX unsafe...
except Exception as x:
raise self.src.to_error(str(x)) from x
else:
if type(result) is bool:
return int(result)
return result
class ExpressionTransformer(EvaluatingTransformer):
def _dotted_name_from_attr(self, node: ast.Attribute) -> str:
if isinstance(node.value, ast.Name):
return node.value.id + '.' + node.attr
if isinstance(node.value, ast.Attribute):
return self._dotted_name_from_attr(node.value) + '.' + node.attr
raise self.error("dotted name error")
def visit_Name(self, node: ast.Name):
# convert true/false names to True/False constants
if node.id == "true":
return ast.copy_location(ast.NameConstant(True), node)
if node.id == "false":
return ast.copy_location(ast.NameConstant(False), node)
return node
def visit_UnaryOp(self, node):
if isinstance(node.operand, ast.Num):
if isinstance(node.op, ast.USub):
node = self.generic_visit(node)
return ast.copy_location(ast.Num(-node.operand.n), node)
if isinstance(node.op, ast.UAdd):
node = self.generic_visit(node)
return ast.copy_location(ast.Num(node.operand.n), node)
if isinstance(node.op, ast.Invert):
if isinstance(node.operand, ast.Num):
node = self.generic_visit(node)
return ast.copy_location(ast.Num(~node.operand.n), node)
else:
raise self.error("can only bitwise invert a number")
raise self.error("expected unary + or - or ~")
elif isinstance(node.operand, ast.UnaryOp):
# nested unary ops, for instance: "~-2" = invert(minus(2))
node = self.generic_visit(node)
return self.visit_UnaryOp(node)
else:
print(node.operand)
raise self.error("expected constant numeric operand for unary operator")
def visit_BinOp(self, node):
node = self.generic_visit(node)
if isinstance(node.op, ast.MatMult):
if isinstance(node.left, ast.Name) and node.left.id == "__ptr":
if isinstance(node.right, ast.Attribute):
symbolname = self._dotted_name_from_attr(node.right)
elif isinstance(node.right, ast.Name):
symbolname = node.right.id
else:
raise self.error("can only take address of a named variable")
try:
address = self.context.get_address(symbolname)
except SymbolError as x:
raise self.error(str(x))
else:
return ast.copy_location(ast.Num(address), node)
else:
raise self.error("invalid MatMult/Pointer node in AST")
return node
def astnode_to_repr(node: ast.AST) -> str:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Num):
return repr(node.n)
if isinstance(node, ast.Str):
return repr(node.s)
if isinstance(node, ast.BinOp):
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): # type: ignore
return '#' + astnode_to_repr(node.right)
else:
print("error", ast.dump(node))
raise TypeError("invalid arg ast node type", node)
if isinstance(node, ast.Attribute):
return astnode_to_repr(node.value) + "." + node.attr
if isinstance(node, ast.UnaryOp):
if isinstance(node.op, ast.USub):
return "-" + astnode_to_repr(node.operand)
if isinstance(node.op, ast.UAdd):
return "+" + astnode_to_repr(node.operand)
if isinstance(node.op, ast.Invert):
return "~" + astnode_to_repr(node.operand)
if isinstance(node.op, ast.Not):
return "not " + astnode_to_repr(node.operand)
if isinstance(node, ast.List):
# indirect values get turned into a list...
return "[" + ",".join(astnode_to_repr(elt) for elt in node.elts) + "]"
raise TypeError("invalid arg ast node type", node)