more optimizations

This commit is contained in:
Irmen de Jong 2018-01-29 21:10:06 +01:00
parent 2a662ba256
commit f82ceab969
9 changed files with 230 additions and 207 deletions

View File

@ -25,7 +25,7 @@ which aims to provide many conveniences over raw assembly code (even when using
- breakpoints, that let the Vice emulator drop into the monitor if execution hits them
- source code labels automatically loaded in Vice emulator so it can show them in disassembly
- conditional gotos
- some code optimizations (such as not repeatedly loading the same value in a register)
- various code optimizations (code structure, logical and numerical expressions, ...)
- @todo: loops
- @todo: memory block operations

View File

@ -11,10 +11,7 @@ import sys
import linecache
from typing import Optional, Tuple, Set, Dict, List, Any, no_type_check
import attr
from .plyparse import parse_file, ParseError, Module, Directive, Block, Subroutine, Scope, VarDef, LiteralValue, \
SubCall, Goto, Return, Assignment, InlineAssembly, Register, Expression, ProgramFormat, ZpOptions,\
SymbolName, Dereference, AddressOf, IncrDecr, AstNode, datatype_of, coerce_constant_value, \
check_symbol_definition, UndefinedSymbolError, process_expression, AugAssignment
from .plyparse import *
from .plylex import SourceRef, print_bold
from .datatypes import DataType, VarType
@ -38,7 +35,7 @@ class PlyParser:
self.check_all_symbolnames(module)
self.create_multiassigns(module)
self.check_and_merge_zeropages(module)
self.process_all_expressions(module)
self.simplify_some_assignments(module)
if not self.imported_module:
# the following shall only be done on the main module after all imports have been done:
self.apply_directive_options(module)
@ -75,7 +72,14 @@ class PlyParser:
# perform semantic analysis / checks on the syntactic parse tree we have so far
# (note: symbol names have already been checked to exist when we start this)
previous_stmt = None
encountered_blocks = set() # type: Set[Block]
for node in module.all_nodes():
if isinstance(node, Block):
parentname = (node.parent.name + ".") if node.parent else ""
blockname = parentname + node.name
if blockname in encountered_blocks:
raise ValueError("block names not unique:", blockname)
encountered_blocks.add(blockname)
if isinstance(node, Scope):
if node.nodes and isinstance(node.parent, (Block, Subroutine)):
if isinstance(node.parent, Block) and node.parent.name != "ZP":
@ -170,27 +174,11 @@ class PlyParser:
for node in module.all_nodes(SymbolName):
check_symbol_definition(node.name, node.my_scope(), node.sourceref) # type: ignore
def process_all_expressions(self, module: Module) -> None:
# process/simplify all expressions (constant folding etc)
encountered_blocks = set() # type: Set[Block]
def simplify_some_assignments(self, module: Module) -> None:
# simplify some assignment statements,
# note taht most of the expression optimization (constant folding etc) is done in the optimizer.
for node in module.all_nodes():
if isinstance(node, Block):
parentname = (node.parent.name + ".") if node.parent else ""
blockname = parentname + node.name
if blockname in encountered_blocks:
raise ValueError("block names not unique:", blockname)
encountered_blocks.add(blockname)
elif isinstance(node, Expression):
try:
evaluated = process_expression(node, node.sourceref)
if evaluated is not node:
# replace the node with the newly evaluated result
node.parent.replace_node(node, evaluated)
except ParseError:
raise
except Exception as x:
self.handle_internal_error(x, "process_expressions of node {}".format(node))
elif isinstance(node, IncrDecr) and node.howmuch not in (0, 1):
if isinstance(node, IncrDecr) and node.howmuch not in (0, 1):
_, node.howmuch = coerce_constant_value(datatype_of(node.target, node.my_scope()), node.howmuch, node.sourceref)
attr.validate(node)
elif isinstance(node, VarDef):
@ -485,17 +473,6 @@ class PlyParser:
print("\x1b[0m", file=out, end="", flush=True)
raise exc # XXX temporary to see where the error occurred
def handle_internal_error(self, exc: Exception, msg: str="") -> None:
out = sys.stdout
if out.isatty():
print("\x1b[1m", file=out)
print("\nERROR: internal parser error: ", exc, file=out)
if msg:
print(" Message:", msg, end="\n\n")
if out.isatty():
print("\x1b[0m", file=out, end="", flush=True)
raise exc
class Zeropage:
SCRATCH_B1 = 0x02

View File

@ -5,11 +5,8 @@ Here are the data type definitions and -conversions.
Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
"""
import math
import enum
from typing import Tuple, Union
from functools import total_ordering
from .plylex import print_warning, SourceRef
@total_ordering

View File

@ -193,11 +193,13 @@ def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None:
out("\vjsr c64flt.float_add_one")
else:
out("\vjsr c64flt.float_sub_one")
elif NOTYETIMPLEMENTED: # XXX for the float += otherfloat cases
else:
# XXX for the float += otherfloat cases
print("FLOAT INCR/DECR BY", stmt.howmuch) # XXX
with preserving_registers({'A', 'X', 'Y'}, scope, out, loads_a_within=True):
out("\vlda #<" + stmt.value.name)
# XXX out("\vlda #<" + stmt.value.name)
out("\vsta c64.SCRATCH_ZPWORD1")
out("\vlda #>" + stmt.value.name)
# XXX out("\vlda #>" + stmt.value.name)
out("\vsta c64.SCRATCH_ZPWORD1+1")
out("\vldx #<" + what_str)
out("\vldy #>" + what_str)
@ -205,8 +207,6 @@ def generate_incrdecr(out: Callable, stmt: IncrDecr, scope: Scope) -> None:
out("\vjsr c64flt.float_add_SW1_to_XY")
else:
out("\vjsr c64flt.float_sub_SW1_from_XY")
else:
raise CodeError("incr/decr missing float constant definition")
else:
raise CodeError("cannot in/decrement memory of type " + str(target.datatype), stmt.howmuch)

View File

@ -129,6 +129,7 @@ def generate_block_vars(out: Callable, block: Block, zeropage: bool=False) -> No
_generate_string_var(out, vardef)
else:
raise CodeError("invalid const type", vardef)
# @todo float constants that are used in expressions
out("; memory mapped variables")
for vardef in vars_by_vartype.get(VarType.MEMORY, []):
# create a definition for variables at a specific place in memory (memory-mapped)

View File

@ -83,7 +83,7 @@ def main() -> None:
parsed_module = parser.parse_file(args.sourcefile)
if parsed_module:
if args.nooptimize:
print_bold("not optimizing the parse tree!")
print_bold("Optimizations disabled!")
else:
print("\nOptimizing code.")
optimize(parsed_module)

View File

@ -6,21 +6,34 @@ eliminates statements that have no effect, optimizes calculations etc.
Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
"""
from typing import List, no_type_check, Union
from .plyparse import AstNode, Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr,\
datatype_of, coerce_constant_value, AssignmentTargets, LiteralValue, Scope, Register, SymbolName, \
Dereference, TargetRegisters, VarDef
import sys
from typing import List, no_type_check, Union, Any
from .plyparse import *
from .plylex import print_warning, print_bold, SourceRef
from .datatypes import DataType
from .datatypes import DataType, VarType
class Optimizer:
def __init__(self, mod: Module) -> None:
self.num_warnings = 0
self.module = mod
self.optimizations_performed = False
def optimize(self) -> None:
self.num_warnings = 0
self.optimizations_performed = True
# keep optimizing as long as there were changes made
while self.optimizations_performed:
self.optimizations_performed = False
self._optimize()
# remaining optimizations that have to be done just once:
self.remove_unused_subroutines()
self.remove_empty_blocks()
def _optimize(self) -> None:
self.constant_folding()
# @todo expression optimization: reduce expression nesting
# @todo expression optimization: simplify logical expression when a term makes it always true or false
self.create_aug_assignments()
self.optimize_assignments()
self.remove_superfluous_assignments()
@ -28,11 +41,32 @@ class Optimizer:
self.optimize_multiassigns()
# @todo optimize some simple multiplications into shifts (A*=8 -> A<<3)
# @todo optimize addition with self into shift 1 (A+=A -> A<<=1)
self.remove_unused_subroutines()
self.optimize_goto_compare_with_zero()
self.join_incrdecrs()
# @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to)
self.remove_empty_blocks()
def handle_internal_error(self, exc: Exception, msg: str="") -> None:
out = sys.stdout
if out.isatty():
print("\x1b[1m", file=out)
print("\nERROR: internal parser/optimizer error: ", exc, file=out)
if msg:
print(" Message:", msg, end="\n\n")
if out.isatty():
print("\x1b[0m", file=out, end="", flush=True)
raise exc
def constant_folding(self) -> None:
for expression in self.module.all_nodes(Expression):
try:
evaluated = process_expression(expression, expression.sourceref) # type: ignore
if evaluated is not expression:
# replace the node with the newly evaluated result
expression.parent.replace_node(expression, evaluated)
except ParseError:
raise
except Exception as x:
self.handle_internal_error(x, "process_expressions of node {}".format(expression))
def join_incrdecrs(self) -> None:
for scope in self.module.all_nodes(Scope):
@ -88,6 +122,7 @@ class Optimizer:
incrdecr = self._make_incrdecr(incrdecrs[0], target, total, "--")
scope.replace_node(incrdecrs[0], incrdecr)
if replaced:
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: merged a sequence of incr/decrs or augmented assignments".format(incrdecrs[0].sourceref))
incrdecrs.clear()
@ -122,6 +157,7 @@ class Optimizer:
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True
continue
if expr.operator not in ('+', '*', '|', '^'): # associative operators
continue
@ -130,11 +166,13 @@ class Optimizer:
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True
elif isinstance(expr.left, (LiteralValue, SymbolName)) and self._same_target(assignment.left.nodes[0], expr.right):
num_val = expr.left.const_num_val()
operator = expr.operator + '='
aug_assign = self._make_aug_assign(assignment, assignment.left.nodes[0], num_val, operator)
assignment.my_scope().replace_node(assignment, aug_assign)
self.optimizations_performed = True
def remove_superfluous_assignments(self) -> None:
# remove consecutive assignment statements to the same target, only keep the last value (only if its a constant!)
@ -146,6 +184,7 @@ class Optimizer:
if isinstance(node.right, (LiteralValue, Register)) and node.left.same_targets(prev_node.left):
if not node.left.has_memvalue():
scope.remove_node(prev_node)
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: removed superfluous assignment".format(prev_node.sourceref))
prev_node = node
@ -160,6 +199,7 @@ class Optimizer:
if isinstance(assignment, Assignment):
if all(lv == assignment.right for lv in assignment.left.nodes):
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
if isinstance(assignment, AugAssignment):
@ -169,22 +209,26 @@ class Optimizer:
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
elif assignment.operator == "*=":
self.num_warnings += 1
print_warning("{}: statement replaced by = 0".format(assignment.sourceref))
new_assignment = self._make_new_assignment(assignment, 0)
assignment.my_scope().replace_node(assignment, new_assignment)
self.optimizations_performed = True
elif assignment.operator == "**=":
self.num_warnings += 1
print_warning("{}: statement replaced by = 1".format(assignment.sourceref))
new_assignment = self._make_new_assignment(assignment, 1)
assignment.my_scope().replace_node(assignment, new_assignment)
self.optimizations_performed = True
if assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="):
print("{}: shifting result is always zero".format(assignment.sourceref))
new_stmt = Assignment(sourceref=assignment.sourceref)
new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref))
new_stmt.nodes.append(LiteralValue(value=0, sourceref=assignment.sourceref))
assignment.my_scope().replace_node(assignment, new_stmt)
self.optimizations_performed = True
if assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256:
howmuch = assignment.right
if howmuch.value not in (0, 1):
@ -194,10 +238,12 @@ class Optimizer:
howmuch=howmuch.value, sourceref=assignment.sourceref)
new_stmt.target = assignment.left
assignment.my_scope().replace_node(assignment, new_stmt)
self.optimizations_performed = True
if assignment.right.value == 1 and assignment.operator in ("/=", "//=", "*="):
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
assignment.my_scope().remove_node(assignment)
self.optimizations_performed = True
@no_type_check
def _make_new_assignment(self, old_aug_assignment: AugAssignment, constantvalue: int) -> Assignment:
@ -222,7 +268,7 @@ class Optimizer:
@no_type_check
def _make_incrdecr(self, old_stmt: AstNode, target: Union[TargetRegisters, Register, SymbolName, Dereference],
howmuch: Union[int, float], operator: str) -> AugAssignment:
howmuch: Union[int, float], operator: str) -> IncrDecr:
a = IncrDecr(operator=operator, howmuch=howmuch, sourceref=old_stmt.sourceref)
a.nodes.append(target)
a.parent = old_stmt.parent
@ -245,6 +291,7 @@ class Optimizer:
print("{}: joined with previous assignment".format(assignment.sourceref))
assignments[0].left.nodes.extend(assignment.left.nodes)
scope.remove_node(assignment)
self.optimizations_performed = True
rvalue = None
assignments.clear()
else:
@ -263,8 +310,9 @@ class Optimizer:
lvalues = set(assignment.left.nodes)
if len(lvalues) != len(assignment.left.nodes):
print("{}: removed duplicate assignment targets".format(assignment.sourceref))
# @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any)
assignment.left.nodes = list(lvalues)
# @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any)
assignment.left.nodes = list(lvalues)
self.optimizations_performed = True
@no_type_check
def remove_unused_subroutines(self) -> None:
@ -337,6 +385,155 @@ class Optimizer:
node.my_scope().nodes.remove(node)
def process_expression(expr: Expression, sourceref: SourceRef) -> Any:
# process/simplify all expressions (constant folding etc)
if expr.must_be_constant:
return process_constant_expression(expr, sourceref)
else:
return process_dynamic_expression(expr, sourceref)
def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, Expression):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue):
return value
elif isinstance(value, (int, float, str, bool)):
raise TypeError("symbol value node should not be a python primitive value", expr)
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, AddressOf):
assert isinstance(expr.name, SymbolName)
value = check_symbol_definition(expr.name.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
if isinstance(value.value, LiteralValue):
return value.value
else:
raise ExpressionEvaluationError("constant literal value required", value.sourceref)
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.name.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.name.sourceref)
elif isinstance(expr, SubCall):
if isinstance(expr.target, SymbolName): # 'function(1,2,3)'
funcname = expr.target.name
if funcname in math_functions or funcname in builtin_functions:
func_args = []
for a in (process_constant_expression(callarg.value, sourceref) for callarg in expr.arguments.nodes):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
elif isinstance(expr.target, Dereference): # '[...](1,2,3)'
raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref)
elif type(expr.target) is int: # '64738()'
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else:
raise NotImplementedError("weird call target", expr.target)
elif not isinstance(expr, Expression):
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_constant_expression(expr.right, right_sourceref)
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
# constant-fold a dynamic expression
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, SymbolName):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
elif isinstance(expr, AddressOf):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
elif isinstance(expr, SubCall):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr
elif isinstance(expr, Register):
return expr
elif isinstance(expr, Dereference):
if isinstance(expr.operand, SymbolName):
check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref)
return expr
elif not isinstance(expr, Expression):
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_dynamic_expression(expr.right, right_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
def optimize(mod: Module) -> None:
opt = Optimizer(mod)
opt.optimize()

View File

@ -10,7 +10,7 @@ import builtins
import inspect
import enum
from collections import defaultdict
from typing import Union, Generator, Tuple, Sequence, List, Optional, Dict, Any, no_type_check
from typing import Union, Generator, Tuple, List, Optional, Dict, Any, no_type_check
import attr
from ply.yacc import yacc
from .plylex import SourceRef, tokens, lexer, find_tok_column, print_warning
@ -809,155 +809,6 @@ def coerce_constant_value(datatype: DataType, value: AstNode,
return False, value
def process_expression(expr: Expression, sourceref: SourceRef) -> Any:
# process/simplify all expressions (constant folding etc)
if expr.must_be_constant:
return process_constant_expression(expr, sourceref)
else:
return process_dynamic_expression(expr, sourceref)
def process_constant_expression(expr: Any, sourceref: SourceRef) -> LiteralValue:
# the expression must result in a single (constant) value (int, float, whatever) wrapped as LiteralValue.
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, SymbolName):
value = check_symbol_definition(expr.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
raise ExpressionEvaluationError("can't take a memory value, must be a constant", expr.sourceref)
value = value.value
if isinstance(value, Expression):
raise ExpressionEvaluationError("circular reference?", expr.sourceref)
elif isinstance(value, LiteralValue):
return value
elif isinstance(value, (int, float, str, bool)):
raise TypeError("symbol value node should not be a python primitive value", expr)
else:
raise ExpressionEvaluationError("constant symbol required, not {}".format(value.__class__.__name__), expr.sourceref)
elif isinstance(expr, AddressOf):
assert isinstance(expr.name, SymbolName)
value = check_symbol_definition(expr.name.name, expr.my_scope(), expr.sourceref)
if isinstance(value, VarDef):
if value.vartype == VarType.MEMORY:
if isinstance(value.value, LiteralValue):
return value.value
else:
raise ExpressionEvaluationError("constant literal value required", value.sourceref)
if value.vartype == VarType.CONST:
raise ExpressionEvaluationError("can't take the address of a constant", expr.name.sourceref)
raise ExpressionEvaluationError("address-of this {} isn't a compile-time constant"
.format(value.__class__.__name__), expr.name.sourceref)
else:
raise ExpressionEvaluationError("constant address required, not {}"
.format(value.__class__.__name__), expr.name.sourceref)
elif isinstance(expr, SubCall):
if isinstance(expr.target, SymbolName): # 'function(1,2,3)'
funcname = expr.target.name
if funcname in math_functions or funcname in builtin_functions:
func_args = []
for a in (process_constant_expression(callarg.value, sourceref) for callarg in expr.arguments.nodes):
if isinstance(a, LiteralValue):
func_args.append(a.value)
else:
func_args.append(a)
func = math_functions.get(funcname, builtin_functions.get(funcname))
try:
return LiteralValue(value=func(*func_args), sourceref=expr.arguments.sourceref) # type: ignore
except Exception as x:
raise ExpressionEvaluationError(str(x), expr.sourceref)
else:
raise ExpressionEvaluationError("can only use math- or builtin function", expr.sourceref)
elif isinstance(expr.target, Dereference): # '[...](1,2,3)'
raise ExpressionEvaluationError("dereferenced value call is not a constant value", expr.sourceref)
elif type(expr.target) is int: # '64738()'
raise ExpressionEvaluationError("immediate address call is not a constant value", expr.sourceref)
else:
raise NotImplementedError("weird call target", expr.target)
elif not isinstance(expr, Expression):
raise ExpressionEvaluationError("constant value required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
if isinstance(expr.left, LiteralValue) and type(expr.left.value) in (int, float):
try:
if expr.operator == '-':
return LiteralValue(value=-expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator == '~':
return LiteralValue(value=~expr.left.value, sourceref=expr.left.sourceref) # type: ignore
elif expr.operator in ("++", "--"):
raise ValueError("incr/decr should not be an expression")
raise ValueError("invalid unary operator", expr.operator)
except TypeError as x:
raise ParseError(str(x), expr.sourceref) from None
raise ValueError("invalid operand type for unary operator", expr.left, expr.operator)
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_constant_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_constant_expression(expr.right, right_sourceref)
if isinstance(expr.left, LiteralValue):
if isinstance(expr.right, LiteralValue):
return expr.evaluate_primitive_constants(expr.right.sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on right, not {}"
.format(expr.right.__class__.__name__), right_sourceref)
else:
raise ExpressionEvaluationError("constant literal value required on left, not {}"
.format(expr.left.__class__.__name__), left_sourceref)
def process_dynamic_expression(expr: Any, sourceref: SourceRef) -> Any:
# constant-fold a dynamic expression
if isinstance(expr, (int, float, str, bool)):
raise TypeError("expr node should not be a python primitive value", expr, sourceref)
elif expr is None or isinstance(expr, LiteralValue):
return expr
elif isinstance(expr, SymbolName):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
elif isinstance(expr, AddressOf):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
elif isinstance(expr, SubCall):
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
if isinstance(expr.target, SymbolName):
check_symbol_definition(expr.target.name, expr.my_scope(), expr.target.sourceref)
return expr
elif isinstance(expr, Register):
return expr
elif isinstance(expr, Dereference):
if isinstance(expr.operand, SymbolName):
check_symbol_definition(expr.operand.name, expr.my_scope(), expr.operand.sourceref)
return expr
elif not isinstance(expr, Expression):
raise ParseError("expression required, not {}".format(expr.__class__.__name__), expr.sourceref)
if expr.unary:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
else:
left_sourceref = expr.left.sourceref if isinstance(expr.left, AstNode) else sourceref
expr.left = process_dynamic_expression(expr.left, left_sourceref)
right_sourceref = expr.right.sourceref if isinstance(expr.right, AstNode) else sourceref
expr.right = process_dynamic_expression(expr.right, right_sourceref)
try:
return process_constant_expression(expr, sourceref)
except ExpressionEvaluationError:
return expr
def check_symbol_definition(name: str, scope: Scope, sref: SourceRef) -> Any:
try:
return scope.lookup(name)

View File

@ -25,7 +25,7 @@ which aims to provide many conveniences over raw assembly code (even when using
- breakpoints, that let the Vice emulator drop into the monitor if execution hits them
- source code labels automatically loaded in Vice emulator so it can show them in disassembly
- conditional gotos
- some code optimizations (such as not repeatedly loading the same value in a register)
- various code optimizations (code structure, logical and numerical expressions, ...)
- @todo: loops
- @todo: memory block operations