This commit is contained in:
Irmen de Jong 2018-01-21 03:44:04 +01:00
parent 3ea0723c3e
commit eb58119b97
6 changed files with 87 additions and 89 deletions

View File

@ -133,7 +133,7 @@ class PlyParser:
if zpnode.name != "ZP": if zpnode.name != "ZP":
return return
zeropage = Zeropage(module.zp_options) zeropage = Zeropage(module.zp_options)
for vardef in zpnode.scope.filter_nodes(VarDef): for vardef in zpnode.all_nodes(VarDef):
if vardef.datatype.isstring(): if vardef.datatype.isstring():
raise ParseError("cannot put strings in the zeropage", vardef.sourceref) raise ParseError("cannot put strings in the zeropage", vardef.sourceref)
try: try:
@ -158,7 +158,10 @@ class PlyParser:
encountered_blocks.add(blockname) encountered_blocks.add(blockname)
elif isinstance(node, Expression): elif isinstance(node, Expression):
try: try:
process_expression(node, node.my_scope(), node.sourceref) evaluated = process_expression(node, node.my_scope(), node.sourceref)
if evaluated is not node:
# replace the node with the newly evaluated result
node.parent.replace_node(node, evaluated)
except ParseError: except ParseError:
raise raise
except Exception as x: except Exception as x:
@ -408,11 +411,11 @@ class PlyParser:
self.parse_errors += import_parse_errors self.parse_errors += import_parse_errors
else: else:
raise FileNotFoundError("missing il65lib") raise FileNotFoundError("missing il65lib")
# XXX append the imported module's contents (blocks) at the end of the current module # append the imported module's contents (blocks) at the end of the current module
# for block in (node for imported_module in imported for block in (node for imported_module in imported
# for node in imported_module.scope.nodes for node in imported_module.scope.nodes
# if isinstance(node, Block)): if isinstance(node, Block)):
# module.scope.add_node(block) module.scope.add_node(block)
def import_file(self, filename: str) -> Tuple[Module, int]: def import_file(self, filename: str) -> Tuple[Module, int]:
sub_parser = PlyParser(imported_module=True) sub_parser = PlyParser(imported_module=True)

View File

@ -52,19 +52,13 @@ class AssemblyGenerator:
out("\t.end") out("\t.end")
def sanitycheck(self) -> None: def sanitycheck(self) -> None:
start_found = False for label in self.module.all_nodes(Label):
for block, parent in self.module.all_scopes(): if label.name == "start" and label.my_scope().name == "main":
assert isinstance(block, (Module, Block, Subroutine))
for label in block.nodes:
if isinstance(label, Label) and label.name == "start" and block.name == "main":
start_found = True
break
if start_found:
break break
if not start_found: else:
print_bold("ERROR: program entry point is missing ('start' label in 'main' block)\n") print_bold("ERROR: program entry point is missing ('start' label in 'main' block)\n")
raise SystemExit(1) raise SystemExit(1)
all_blocknames = [b.name for b in self.module.scope.filter_nodes(Block)] all_blocknames = [b.name for b in self.module.all_nodes(Block)]
unique_blocknames = set(all_blocknames) unique_blocknames = set(all_blocknames)
if len(all_blocknames) != len(unique_blocknames): if len(all_blocknames) != len(unique_blocknames):
for name in unique_blocknames: for name in unique_blocknames:
@ -137,7 +131,7 @@ class AssemblyGenerator:
generate_block_vars(out, zpblock, True) generate_block_vars(out, zpblock, True)
# there's no code in the zero page block. # there's no code in the zero page block.
out("\v.pend\n") out("\v.pend\n")
for block in sorted(self.module.scope.filter_nodes(Block), key=lambda b: b.address or 0): for block in sorted(self.module.all_nodes(Block), key=lambda b: b.address or 0):
if block.name == "ZP": if block.name == "ZP":
continue # already processed continue # already processed
self.cur_block = block self.cur_block = block
@ -149,7 +143,7 @@ class AssemblyGenerator:
out("{:s}\t.proc\n".format(block.label)) out("{:s}\t.proc\n".format(block.label))
generate_block_init(out, block) generate_block_init(out, block)
generate_block_vars(out, block) generate_block_vars(out, block)
subroutines = list(sub for sub in block.scope.filter_nodes(Subroutine) if sub.address is not None) subroutines = list(sub for sub in block.all_nodes(Subroutine) if sub.address is not None)
if subroutines: if subroutines:
# these are (external) subroutines that are defined by address instead of a scope/code # these are (external) subroutines that are defined by address instead of a scope/code
out("; external subroutines") out("; external subroutines")
@ -164,7 +158,7 @@ class AssemblyGenerator:
if block.name == "main" and isinstance(stmt, Label) and stmt.name == "start": if block.name == "main" and isinstance(stmt, Label) and stmt.name == "start":
# make sure the main.start routine clears the decimal and carry flags as first steps # make sure the main.start routine clears the decimal and carry flags as first steps
out("\vcld\n\vclc\n\vclv") out("\vcld\n\vclc\n\vclv")
subroutines = list(sub for sub in block.scope.filter_nodes(Subroutine) if sub.address is None) subroutines = list(sub for sub in block.all_nodes(Subroutine) if sub.address is None)
if subroutines: if subroutines:
# these are subroutines that are defined by a scope/code # these are subroutines that are defined by a scope/code
out("; -- block subroutines") out("; -- block subroutines")

View File

@ -7,7 +7,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Callable, Any from typing import Dict, List, Callable, Any
from ..plyparse import Block, VarType, VarDef from ..plyparse import Block, VarType, VarDef, LiteralValue
from ..datatypes import DataType, STRING_DATATYPES from ..datatypes import DataType, STRING_DATATYPES
from . import to_hex, to_mflpt5, CodeError from . import to_hex, to_mflpt5, CodeError
@ -56,18 +56,18 @@ def generate_block_init(out: Callable, block: Block) -> None:
float_inits = {} float_inits = {}
prev_value_a, prev_value_x = None, None prev_value_a, prev_value_x = None, None
vars_by_datatype = defaultdict(list) # type: Dict[DataType, List[VarDef]] vars_by_datatype = defaultdict(list) # type: Dict[DataType, List[VarDef]]
for vardef in block.scope.filter_nodes(VarDef): for vardef in block.all_nodes(VarDef):
if vardef.vartype == VarType.VAR: if vardef.vartype == VarType.VAR:
vars_by_datatype[vardef.datatype].append(vardef) vars_by_datatype[vardef.datatype].append(vardef)
for bytevar in sorted(vars_by_datatype[DataType.BYTE], key=lambda vd: vd.value): for bytevar in sorted(vars_by_datatype[DataType.BYTE], key=lambda vd: vd.value):
assert type(bytevar.value) is int assert isinstance(bytevar.value, LiteralValue) and type(bytevar.value.value) is int
if bytevar.value != prev_value_a: if bytevar.value.value != prev_value_a:
out("\vlda #${:02x}".format(bytevar.value)) out("\vlda #${:02x}".format(bytevar.value.value))
prev_value_a = bytevar.value prev_value_a = bytevar.value.value
out("\vsta {:s}".format(bytevar.name)) out("\vsta {:s}".format(bytevar.name))
for wordvar in sorted(vars_by_datatype[DataType.WORD], key=lambda vd: vd.value): for wordvar in sorted(vars_by_datatype[DataType.WORD], key=lambda vd: vd.value):
assert type(wordvar.value) is int assert isinstance(wordvar.value, LiteralValue) and type(wordvar.value.value) is int
v_hi, v_lo = divmod(wordvar.value, 256) v_hi, v_lo = divmod(wordvar.value.value, 256)
if v_hi != prev_value_a: if v_hi != prev_value_a:
out("\vlda #${:02x}".format(v_hi)) out("\vlda #${:02x}".format(v_hi))
prev_value_a = v_hi prev_value_a = v_hi
@ -77,18 +77,18 @@ def generate_block_init(out: Callable, block: Block) -> None:
out("\vsta {:s}".format(wordvar.name)) out("\vsta {:s}".format(wordvar.name))
out("\vstx {:s}+1".format(wordvar.name)) out("\vstx {:s}+1".format(wordvar.name))
for floatvar in vars_by_datatype[DataType.FLOAT]: for floatvar in vars_by_datatype[DataType.FLOAT]:
assert isinstance(floatvar.value, (int, float)) assert isinstance(floatvar.value, LiteralValue) and type(floatvar.value.value) in (int, float)
fpbytes = to_mflpt5(floatvar.value) # type: ignore fpbytes = to_mflpt5(floatvar.value.value) # type: ignore
float_inits[floatvar.name] = (floatvar.name, fpbytes, floatvar.value) float_inits[floatvar.name] = (floatvar.name, fpbytes, floatvar.value)
for arrayvar in vars_by_datatype[DataType.BYTEARRAY]: for arrayvar in vars_by_datatype[DataType.BYTEARRAY]:
assert type(arrayvar.value) is int assert isinstance(arrayvar.value, LiteralValue) and type(arrayvar.value.value) is int
_memset(arrayvar.name, arrayvar.value, arrayvar.size[0]) _memset(arrayvar.name, arrayvar.value.value, arrayvar.size[0])
for arrayvar in vars_by_datatype[DataType.WORDARRAY]: for arrayvar in vars_by_datatype[DataType.WORDARRAY]:
assert type(arrayvar.value) is int assert isinstance(arrayvar.value, LiteralValue) and type(arrayvar.value.value) is int
_memsetw(arrayvar.name, arrayvar.value, arrayvar.size[0]) _memsetw(arrayvar.name, arrayvar.value.value, arrayvar.size[0])
for arrayvar in vars_by_datatype[DataType.MATRIX]: for arrayvar in vars_by_datatype[DataType.MATRIX]:
assert type(arrayvar.value) is int assert isinstance(arrayvar.value, LiteralValue) and type(arrayvar.value.value) is int
_memset(arrayvar.name, arrayvar.value, arrayvar.size[0] * arrayvar.size[1]) _memset(arrayvar.name, arrayvar.value.value, arrayvar.size[0] * arrayvar.size[1])
if float_inits: if float_inits:
out("\vldx #4") out("\vldx #4")
out("-") out("-")
@ -114,7 +114,7 @@ def generate_block_vars(out: Callable, block: Block, zeropage: bool=False) -> No
# The memory bytes of the allocated variables is set to zero (so it compresses very well), # The memory bytes of the allocated variables is set to zero (so it compresses very well),
# their actual starting values are set by the block init code. # their actual starting values are set by the block init code.
vars_by_vartype = defaultdict(list) # type: Dict[VarType, List[VarDef]] vars_by_vartype = defaultdict(list) # type: Dict[VarType, List[VarDef]]
for vardef in block.scope.filter_nodes(VarDef): for vardef in block.all_nodes(VarDef):
vars_by_vartype[vardef.vartype].append(vardef) vars_by_vartype[vardef.vartype].append(vardef)
out("; constants") out("; constants")
for vardef in vars_by_vartype.get(VarType.CONST, []): for vardef in vars_by_vartype.get(VarType.CONST, []):
@ -132,13 +132,13 @@ def generate_block_vars(out: Callable, block: Block, zeropage: bool=False) -> No
# create a definition for variables at a specific place in memory (memory-mapped) # create a definition for variables at a specific place in memory (memory-mapped)
if vardef.datatype.isnumeric(): if vardef.datatype.isnumeric():
assert vardef.size == [1] assert vardef.size == [1]
out("\v{:s} = {:s}\t; {:s}".format(vardef.name, to_hex(vardef.value), vardef.datatype.name.lower())) out("\v{:s} = {:s}\t; {:s}".format(vardef.name, to_hex(vardef.value.value), vardef.datatype.name.lower()))
elif vardef.datatype == DataType.BYTEARRAY: elif vardef.datatype == DataType.BYTEARRAY:
assert len(vardef.size) == 1 assert len(vardef.size) == 1
out("\v{:s} = {:s}\t; array of {:d} bytes".format(vardef.name, to_hex(vardef.value), vardef.size[0])) out("\v{:s} = {:s}\t; array of {:d} bytes".format(vardef.name, to_hex(vardef.value.value), vardef.size[0]))
elif vardef.datatype == DataType.WORDARRAY: elif vardef.datatype == DataType.WORDARRAY:
assert len(vardef.size) == 1 assert len(vardef.size) == 1
out("\v{:s} = {:s}\t; array of {:d} words".format(vardef.name, to_hex(vardef.value), vardef.size[0])) out("\v{:s} = {:s}\t; array of {:d} words".format(vardef.name, to_hex(vardef.value.value), vardef.size[0]))
elif vardef.datatype == DataType.MATRIX: elif vardef.datatype == DataType.MATRIX:
assert len(vardef.size) in (2, 3) assert len(vardef.size) in (2, 3)
if len(vardef.size) == 2: if len(vardef.size) == 2:
@ -147,7 +147,7 @@ def generate_block_vars(out: Callable, block: Block, zeropage: bool=False) -> No
comment = "matrix of {:d} by {:d}, interleave {:d}".format(vardef.size[0], vardef.size[1], vardef.size[2]) comment = "matrix of {:d} by {:d}, interleave {:d}".format(vardef.size[0], vardef.size[1], vardef.size[2])
else: else:
raise CodeError("matrix size should be 2 or 3 numbers") raise CodeError("matrix size should be 2 or 3 numbers")
out("\v{:s} = {:s}\t; {:s}".format(vardef.name, to_hex(vardef.value), comment)) out("\v{:s} = {:s}\t; {:s}".format(vardef.name, to_hex(vardef.value.value), comment))
else: else:
raise CodeError("invalid var type") raise CodeError("invalid var type")
out("; normal variables - initial values will be set by init code") out("; normal variables - initial values will be set by init code")

View File

@ -81,7 +81,6 @@ def main() -> None:
print("\nParsing program source code.") print("\nParsing program source code.")
parser = PlyParser() parser = PlyParser()
parsed_module = parser.parse_file(args.sourcefile) parsed_module = parser.parse_file(args.sourcefile)
raise SystemExit("First fix the parser to iterate all nodes in the new way.") # XXX
if parsed_module: if parsed_module:
if args.nooptimize: if args.nooptimize:
print_bold("not optimizing the parse tree!") print_bold("not optimizing the parse tree!")

View File

@ -6,7 +6,7 @@ Written by Irmen de Jong (irmen@razorvine.net) - license: GNU GPL 3.0
""" """
from .plyparse import Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr,\ from .plyparse import Module, Subroutine, Block, Directive, Assignment, AugAssignment, Goto, Expression, IncrDecr,\
datatype_of, coerce_constant_value, AssignmentTargets datatype_of, coerce_constant_value, AssignmentTargets, LiteralValue
from .plylex import print_warning, print_bold from .plylex import print_warning, print_bold
@ -18,49 +18,51 @@ class Optimizer:
def optimize(self) -> None: def optimize(self) -> None:
self.num_warnings = 0 self.num_warnings = 0
self.optimize_assignments() self.optimize_assignments()
self.combine_assignments_into_multi() return # XXX fix all methods below
self.optimize_multiassigns() #self.combine_assignments_into_multi()
self.remove_unused_subroutines() #self.optimize_multiassigns()
self.optimize_compare_with_zero() #self.remove_unused_subroutines()
#self.optimize_compare_with_zero()
# @todo join multiple incr/decr of same var into one (if value stays < 256) # @todo join multiple incr/decr of same var into one (if value stays < 256)
# @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to) # @todo analyse for unreachable code and remove that (f.i. code after goto or return that has no label so can never be jumped to)
self.remove_empty_blocks() #self.remove_empty_blocks()
def optimize_assignments(self): def optimize_assignments(self) -> None:
# remove assignment statements that do nothing (A=A) # remove assignment statements that do nothing (A=A)
# and augmented assignments that have no effect (A+=0) # and augmented assignments that have no effect (A+=0)
# convert augmented assignments to simple incr/decr if possible (A+=10 => A++ by 10) # convert augmented assignments to simple incr/decr if possible (A+=10 => A++ by 10)
# @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?) # @todo remove or simplify logical aug assigns like A |= 0, A |= true, A |= false (or perhaps turn them into byte values first?)
for block, parent in self.module.all_scopes(): for assignment in self.module.all_nodes():
for assignment in list(block.nodes): if isinstance(assignment, Assignment):
if isinstance(assignment, Assignment): if any(lv != assignment.right for lv in assignment.left.nodes):
assignment.left = [lv for lv in assignment.left if lv != assignment.right] assignment.left.nodes = [lv for lv in assignment.left.nodes if lv != assignment.right]
if not assignment.left: if not assignment.left:
block.scope.remove_node(assignment) assignment.my_scope().remove_node(assignment)
self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
if isinstance(assignment, AugAssignment):
if isinstance(assignment.right, LiteralValue) and isinstance(assignment.right.value, (int, float)):
if assignment.right.value == 0 and assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="):
self.num_warnings += 1 self.num_warnings += 1
print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) print_warning("{}: removed statement that has no effect".format(assignment.sourceref))
if isinstance(assignment, AugAssignment): assignment.my_scope().remove_node(assignment)
if isinstance(assignment.right, (int, float)): if assignment.right.value >= 8 and assignment.operator in ("<<=", ">>="):
if assignment.right == 0 and assignment.operator in ("+=", "-=", "|=", "<<=", ">>=", "^="): print("{}: shifting result is always zero".format(assignment.sourceref))
self.num_warnings += 1 new_stmt = Assignment(sourceref=assignment.sourceref)
print_warning("{}: removed statement that has no effect".format(assignment.sourceref)) new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref))
block.scope.remove_node(assignment) new_stmt.nodes.append(0)
if assignment.right >= 8 and assignment.operator in ("<<=", ">>="): assignment.my_scope().replace_node(assignment, new_stmt)
print("{}: shifting result is always zero".format(assignment.sourceref)) if assignment.operator in ("+=", "-=") and 0 < assignment.right.value < 256:
new_stmt = Assignment(sourceref=assignment.sourceref) howmuch = assignment.right
new_stmt.nodes.append(AssignmentTargets(nodes=[assignment.left], sourceref=assignment.sourceref)) if howmuch.value not in (0, 1):
new_stmt.nodes.append(0) _, howmuch = coerce_constant_value(datatype_of(assignment.left, assignment.my_scope()),
block.scope.replace_node(assignment, new_stmt) howmuch, assignment.sourceref)
if assignment.operator in ("+=", "-=") and 0 < assignment.right < 256: new_stmt = IncrDecr(operator="++" if assignment.operator == "+=" else "--",
howmuch = assignment.right howmuch=howmuch.value, sourceref=assignment.sourceref)
if howmuch not in (0, 1): new_stmt.target = assignment.left
_, howmuch = coerce_constant_value(datatype_of(assignment.left, block.scope), howmuch, assignment.sourceref) assignment.my_scope().replace_node(assignment, new_stmt)
new_stmt = IncrDecr(operator="++" if assignment.operator == "+=" else "--",
howmuch=howmuch, sourceref=assignment.sourceref)
new_stmt.target = assignment.left
block.scope.replace_node(assignment, new_stmt)
def combine_assignments_into_multi(self): def combine_assignments_into_multi(self) -> None:
# fold multiple consecutive assignments with the same rvalue into one multi-assignment # fold multiple consecutive assignments with the same rvalue into one multi-assignment
for block, parent in self.module.all_scopes(): for block, parent in self.module.all_scopes():
rvalue = None rvalue = None
@ -86,7 +88,7 @@ class Optimizer:
rvalue = None rvalue = None
assignments.clear() assignments.clear()
def optimize_multiassigns(self): def optimize_multiassigns(self) -> None:
# optimize multi-assign statements (remove duplicate targets, optimize order) # optimize multi-assign statements (remove duplicate targets, optimize order)
for block, parent in self.module.all_scopes(): for block, parent in self.module.all_scopes():
for assignment in block.nodes: for assignment in block.nodes:
@ -98,7 +100,7 @@ class Optimizer:
# @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any) # @todo change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any)
assignment.left = list(lvalues) assignment.left = list(lvalues)
def remove_unused_subroutines(self): def remove_unused_subroutines(self) -> None:
# some symbols are used by the emitted assembly code from the code generator, # some symbols are used by the emitted assembly code from the code generator,
# and should never be removed or the assembler will fail # and should never be removed or the assembler will fail
never_remove = {"c64.FREADUY", "c64.FTOMEMXY", "c64.FADD", "c64.FSUB", never_remove = {"c64.FREADUY", "c64.FTOMEMXY", "c64.FADD", "c64.FSUB",
@ -114,14 +116,14 @@ class Optimizer:
if num_discarded: if num_discarded:
print("discarded {:d} unused subroutines".format(num_discarded)) print("discarded {:d} unused subroutines".format(num_discarded))
def optimize_compare_with_zero(self): def optimize_compare_with_zero(self) -> None:
# a conditional goto that compares a value with zero will be simplified # a conditional goto that compares a value with zero will be simplified
# the comparison operator and rvalue (0) will be removed and the if-status changed accordingly # the comparison operator and rvalue (0) will be removed and the if-status changed accordingly
for block, parent in self.module.all_scopes(): for block, parent in self.module.all_scopes():
if block.scope: if block.scope:
for stmt in block.scope.filter_nodes(Goto): for goto in block.all_nodes(Goto):
if isinstance(stmt.condition, Expression): if isinstance(goto.condition, Expression):
print("NOT IMPLEMENTED YET: optimize goto conditionals", stmt.condition) # @todo print("NOT IMPLEMENTED YET: optimize goto conditionals", goto.condition) # @todo
# if cond and isinstance(cond.rvalue, (int, float)) and cond.rvalue.value == 0: # if cond and isinstance(cond.rvalue, (int, float)) and cond.rvalue.value == 0:
# simplified = False # simplified = False
# if cond.ifstatus in ("true", "ne"): # if cond.ifstatus in ("true", "ne"):
@ -149,7 +151,7 @@ class Optimizer:
def remove_empty_blocks(self) -> None: def remove_empty_blocks(self) -> None:
# remove blocks without name and without address, or that are empty # remove blocks without name and without address, or that are empty
for node, parent in self.module.all_scopes(): for node in self.module.all_nodes():
if isinstance(node, (Subroutine, Block)): if isinstance(node, (Subroutine, Block)):
if not node.scope: if not node.scope:
continue continue
@ -160,14 +162,14 @@ class Optimizer:
if empty: if empty:
self.num_warnings += 1 self.num_warnings += 1
print_warning("ignoring empty block or subroutine", node.sourceref) print_warning("ignoring empty block or subroutine", node.sourceref)
assert isinstance(parent, (Block, Module)) assert isinstance(node.parent, (Block, Module))
parent.scope.nodes.remove(node) node.my_scope().nodes.remove(node)
if isinstance(node, Block): if isinstance(node, Block):
if not node.name and node.address is None: if not node.name and node.address is None:
self.num_warnings += 1 self.num_warnings += 1
print_warning("ignoring block without name and address", node.sourceref) print_warning("ignoring block without name and address", node.sourceref)
assert isinstance(parent, Module) assert isinstance(node.parent, Module)
parent.scope.nodes.remove(node) node.my_scope().nodes.remove(node)
def optimize(mod: Module) -> None: def optimize(mod: Module) -> None:

View File

@ -270,7 +270,7 @@ class Module(AstNode):
@no_type_check @no_type_check
def zeropage(self) -> Optional[Block]: def zeropage(self) -> Optional[Block]:
# return the zeropage block (if defined) # return the zeropage block (if defined)
first_block = next(self.scope.filter_nodes(Block)) first_block = next(self.scope.all_nodes(Block))
if first_block.name == "ZP": if first_block.name == "ZP":
return first_block return first_block
return None return None
@ -278,7 +278,7 @@ class Module(AstNode):
@no_type_check @no_type_check
def main(self) -> Optional[Block]: def main(self) -> Optional[Block]:
# return the 'main' block (if defined) # return the 'main' block (if defined)
for block in self.scope.filter_nodes(Block): for block in self.scope.all_nodes(Block):
if block.name == "main": if block.name == "main":
return block return block
return None return None