mirror of
https://github.com/irmen/prog8.git
synced 2025-02-19 11:31:07 +00:00
optimization for subs
This commit is contained in:
parent
b4d82ba8e6
commit
bca33f8765
@ -239,7 +239,7 @@ class CodeGenerator:
|
|||||||
self.p("; end external subroutines")
|
self.p("; end external subroutines")
|
||||||
for stmt in block.statements:
|
for stmt in block.statements:
|
||||||
self.generate_statement(stmt)
|
self.generate_statement(stmt)
|
||||||
subroutines = list(sub for sub in block.symbols.iter_subroutines() if sub.address is None)
|
subroutines = list(sub for sub in block.symbols.iter_subroutines(True))
|
||||||
if subroutines:
|
if subroutines:
|
||||||
self.p("\n; block subroutines")
|
self.p("\n; block subroutines")
|
||||||
for subdef in subroutines:
|
for subdef in subroutines:
|
||||||
@ -532,8 +532,6 @@ class CodeGenerator:
|
|||||||
self.p("\t\tjsr " + targetstr)
|
self.p("\t\tjsr " + targetstr)
|
||||||
|
|
||||||
def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None:
|
def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None:
|
||||||
rvalue = stmt.right
|
|
||||||
|
|
||||||
def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue:
|
def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue:
|
||||||
if isinstance(iv.value, ParseResult.MemMappedValue):
|
if isinstance(iv.value, ParseResult.MemMappedValue):
|
||||||
return iv.value
|
return iv.value
|
||||||
@ -542,6 +540,7 @@ class CodeGenerator:
|
|||||||
else:
|
else:
|
||||||
raise CodeError("cannot yet generate code for assignment: non-constant and non-memmapped indirect") # XXX
|
raise CodeError("cannot yet generate code for assignment: non-constant and non-memmapped indirect") # XXX
|
||||||
|
|
||||||
|
rvalue = stmt.right
|
||||||
if isinstance(rvalue, ParseResult.IndirectValue):
|
if isinstance(rvalue, ParseResult.IndirectValue):
|
||||||
rvalue = unwrap_indirect(rvalue)
|
rvalue = unwrap_indirect(rvalue)
|
||||||
self.p("\t\t\t\t\t; src l. {:d}".format(stmt.lineno))
|
self.p("\t\t\t\t\t; src l. {:d}".format(stmt.lineno))
|
||||||
|
120
il65/parse.py
120
il65/parse.py
@ -379,13 +379,16 @@ class ParseResult:
|
|||||||
self.right.name = stringvar_name
|
self.right.name = stringvar_name
|
||||||
self._immediate_string_vars[self.right.value] = (cur_block.name, stringvar_name)
|
self._immediate_string_vars[self.right.value] = (cur_block.name, stringvar_name)
|
||||||
|
|
||||||
def remove_identity_assigns(self) -> None:
|
def remove_identity_lvalues(self, filename: str, lineno: int) -> None:
|
||||||
for lv in self.leftvalues:
|
for lv in self.leftvalues:
|
||||||
if lv == self.right:
|
if lv == self.right:
|
||||||
print("warning: {:d}: removed identity assignment".format(self.lineno))
|
print("{:s}:{:d}: removed identity assignment".format(filename, lineno))
|
||||||
remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right]
|
remaining_leftvalues = [lv for lv in self.leftvalues if lv != self.right]
|
||||||
self.leftvalues = remaining_leftvalues
|
self.leftvalues = remaining_leftvalues
|
||||||
|
|
||||||
|
def is_identity(self) -> bool:
|
||||||
|
return all(lv == self.right for lv in self.leftvalues)
|
||||||
|
|
||||||
class ReturnStmt(_AstNode):
|
class ReturnStmt(_AstNode):
|
||||||
def __init__(self, a: Optional['ParseResult.Value']=None,
|
def __init__(self, a: Optional['ParseResult.Value']=None,
|
||||||
x: Optional['ParseResult.Value']=None,
|
x: Optional['ParseResult.Value']=None,
|
||||||
@ -417,8 +420,9 @@ class ParseResult:
|
|||||||
for name, value in self.arguments:
|
for name, value in self.arguments:
|
||||||
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
|
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
|
||||||
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
|
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
|
||||||
assignment.lineno = self.lineno
|
if not assignment.is_identity():
|
||||||
self.desugared_call_arguments.append(assignment)
|
assignment.lineno = self.lineno
|
||||||
|
self.desugared_call_arguments.append(assignment)
|
||||||
|
|
||||||
class InlineAsm(_AstNode):
|
class InlineAsm(_AstNode):
|
||||||
def __init__(self, lineno: int, asmlines: List[str]) -> None:
|
def __init__(self, lineno: int, asmlines: List[str]) -> None:
|
||||||
@ -453,7 +457,7 @@ class Parser:
|
|||||||
self.lines = self.load_source(filename)
|
self.lines = self.load_source(filename)
|
||||||
self.outputdir = outputdir
|
self.outputdir = outputdir
|
||||||
self.parsing_import = parsing_import # are we parsing a import file?
|
self.parsing_import = parsing_import # are we parsing a import file?
|
||||||
self.cur_lineidx = -1
|
self._cur_lineidx = -1 # used to efficiently go to next/previous line in source
|
||||||
self.cur_block = None # type: ParseResult.Block
|
self.cur_block = None # type: ParseResult.Block
|
||||||
self.root_scope = SymbolTable("<root>", None, None)
|
self.root_scope = SymbolTable("<root>", None, None)
|
||||||
self.ppsymbols = ppsymbols # symboltable from preprocess phase
|
self.ppsymbols = ppsymbols # symboltable from preprocess phase
|
||||||
@ -536,50 +540,62 @@ class Parser:
|
|||||||
raise self.PError("invalid statement or characters, block expected")
|
raise self.PError("invalid statement or characters, block expected")
|
||||||
if not self.parsing_import:
|
if not self.parsing_import:
|
||||||
# check if we have a proper main block to contain the program's entry point
|
# check if we have a proper main block to contain the program's entry point
|
||||||
|
main_found = False
|
||||||
for block in self.result.blocks:
|
for block in self.result.blocks:
|
||||||
if block.name == "main":
|
if block.name == "main":
|
||||||
|
main_found = True
|
||||||
if "start" not in block.label_names:
|
if "start" not in block.label_names:
|
||||||
self.sourceref.line = block.sourceref.line
|
self.sourceref.line = block.sourceref.line
|
||||||
self.sourceref.column = 0
|
self.sourceref.column = 0
|
||||||
raise self.PError("The 'main' block should contain the program entry point 'start'")
|
raise self.PError("The 'main' block should contain the program entry point 'start'")
|
||||||
if not any(s for s in block.statements if isinstance(s, ParseResult.ReturnStmt)):
|
self._check_return_statement(block, "'main' block")
|
||||||
self.print_warning("warning: {}: The 'main' block is lacking a return statement.".format(block.sourceref))
|
for sub in block.symbols.iter_subroutines(True):
|
||||||
break
|
self._check_return_statement(sub.sub_block, "'{:s}' subroutine".format(sub.name))
|
||||||
else:
|
if not main_found:
|
||||||
raise self.PError("A block named 'main' should be defined for the program's entry point 'start'")
|
raise self.PError("A block named 'main' should be defined for the program's entry point 'start'")
|
||||||
|
|
||||||
|
def _check_return_statement(self, block: ParseResult.Block, message: str) -> None:
|
||||||
|
# find last statement that isn't a comment
|
||||||
|
for stmt in reversed(block.statements):
|
||||||
|
if isinstance(stmt, ParseResult.Comment):
|
||||||
|
continue
|
||||||
|
if isinstance(stmt, ParseResult.ReturnStmt) or isinstance(stmt, ParseResult.CallStmt) and stmt.is_goto:
|
||||||
|
return
|
||||||
|
if isinstance(stmt, ParseResult.InlineAsm):
|
||||||
|
# check that the last asm line is a jmp or a rts
|
||||||
|
for asmline in reversed(stmt.asmlines):
|
||||||
|
if asmline.lstrip().startswith(';'):
|
||||||
|
continue
|
||||||
|
if " rts" in asmline or "\trts" in asmline or " jmp" in asmline or "\tjmp" in asmline:
|
||||||
|
return
|
||||||
|
if asmline.strip():
|
||||||
|
if asmline.split()[0].isidentifier():
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
break
|
||||||
|
self.print_warning("warning: {}: The {:s} doesn't end with a return statement".format(block.sourceref, message))
|
||||||
|
|
||||||
def _parse_2(self) -> None:
|
def _parse_2(self) -> None:
|
||||||
# parsing pass 2 (not done during preprocessing!)
|
# parsing pass 2 (not done during preprocessing!)
|
||||||
self.cur_block = None
|
self.cur_block = None
|
||||||
self.sourceref.line = -1
|
self.sourceref.line = -1
|
||||||
self.sourceref.column = 0
|
self.sourceref.column = 0
|
||||||
|
|
||||||
for block in self.result.blocks:
|
for block in self.result.blocks:
|
||||||
self.cur_block = block
|
self.cur_block = block
|
||||||
# remove identity assignments
|
|
||||||
have_removed_stmts = False
|
|
||||||
for index, stmt in enumerate(list(block.statements)):
|
|
||||||
if isinstance(stmt, ParseResult.AssignmentStmt):
|
|
||||||
stmt.remove_identity_assigns()
|
|
||||||
if not stmt.leftvalues:
|
|
||||||
print("warning: {:s}:{:d}: removed identity assignment statement".format(self.sourceref.file, stmt.lineno))
|
|
||||||
have_removed_stmts = True
|
|
||||||
block.statements[index] = None
|
|
||||||
if have_removed_stmts:
|
|
||||||
# remove the Nones
|
|
||||||
block.statements = [s for s in block.statements if s is not None]
|
|
||||||
# create parameter loads for calls
|
# create parameter loads for calls
|
||||||
for index, stmt in enumerate(list(block.statements)):
|
for index, stmt in enumerate(list(block.statements)):
|
||||||
if isinstance(stmt, ParseResult.CallStmt):
|
if isinstance(stmt, ParseResult.CallStmt):
|
||||||
self.sourceref.line = stmt.lineno
|
self.sourceref.line = stmt.lineno
|
||||||
self.sourceref.column = 0
|
self.sourceref.column = 0
|
||||||
stmt.desugar_call_arguments(self)
|
stmt.desugar_call_arguments(self)
|
||||||
for sub in block.symbols.iter_subroutines():
|
# create parameter loads for calls, in subroutine blocks
|
||||||
if sub.address is None and sub.sub_block:
|
for sub in block.symbols.iter_subroutines(True):
|
||||||
for stmt in sub.sub_block.statements:
|
for stmt in sub.sub_block.statements:
|
||||||
if isinstance(stmt, ParseResult.CallStmt):
|
if isinstance(stmt, ParseResult.CallStmt):
|
||||||
self.sourceref.line = stmt.lineno
|
self.sourceref.line = stmt.lineno
|
||||||
self.sourceref.column = 0
|
self.sourceref.column = 0
|
||||||
stmt.desugar_call_arguments(self)
|
stmt.desugar_call_arguments(self)
|
||||||
block.flatten_statement_list()
|
block.flatten_statement_list()
|
||||||
# desugar immediate string value assignments
|
# desugar immediate string value assignments
|
||||||
for index, stmt in enumerate(list(block.statements)):
|
for index, stmt in enumerate(list(block.statements)):
|
||||||
@ -594,38 +610,32 @@ class Parser:
|
|||||||
stmt.desugar_immediate_string(self)
|
stmt.desugar_immediate_string(self)
|
||||||
|
|
||||||
def next_line(self) -> str:
|
def next_line(self) -> str:
|
||||||
self.cur_lineidx += 1
|
self._cur_lineidx += 1
|
||||||
try:
|
try:
|
||||||
self.sourceref.line, line = self.lines[self.cur_lineidx]
|
self.sourceref.line, line = self.lines[self._cur_lineidx]
|
||||||
self.sourceref.column = 0
|
self.sourceref.column = 0
|
||||||
return line
|
return line
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def prev_line(self) -> str:
|
def prev_line(self) -> str:
|
||||||
self.cur_lineidx -= 1
|
self._cur_lineidx -= 1
|
||||||
self.sourceref.line, line = self.lines[self.cur_lineidx]
|
self.sourceref.line, line = self.lines[self._cur_lineidx]
|
||||||
return line
|
return line
|
||||||
|
|
||||||
def peek_next_line(self) -> str:
|
def peek_next_line(self) -> str:
|
||||||
if (self.cur_lineidx + 1) < len(self.lines):
|
if (self._cur_lineidx + 1) < len(self.lines):
|
||||||
return self.lines[self.cur_lineidx + 1][1]
|
return self.lines[self._cur_lineidx + 1][1]
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def PError(self, message: str, lineno: int=0, column: int=0) -> ParseError:
|
def PError(self, message: str, lineno: int=0, column: int=0) -> ParseError:
|
||||||
sourceline = ""
|
sourceline = ""
|
||||||
if lineno:
|
lineno = lineno or self.sourceref.line
|
||||||
for num, text in self.lines:
|
column = column or self.sourceref.column
|
||||||
if num == lineno:
|
for num, text in self.lines:
|
||||||
sourceline = text.strip()
|
if num == lineno:
|
||||||
break
|
sourceline = text.strip()
|
||||||
else:
|
break
|
||||||
lineno = self.sourceref.line
|
|
||||||
column = self.sourceref.column
|
|
||||||
self.cur_lineidx = min(self.cur_lineidx, len(self.lines) - 1)
|
|
||||||
if self.cur_lineidx:
|
|
||||||
sourceline = self.lines[self.cur_lineidx][1].strip()
|
|
||||||
# XXX source line is wrong when dealing with errors in sub call
|
|
||||||
return ParseError(message, sourceline, SourceRef(self.sourceref.file, lineno, column))
|
return ParseError(message, sourceline, SourceRef(self.sourceref.file, lineno, column))
|
||||||
|
|
||||||
def get_datatype(self, typestr: str) -> Tuple[DataType, int, Optional[Tuple[int, int]]]:
|
def get_datatype(self, typestr: str) -> Tuple[DataType, int, Optional[Tuple[int, int]]]:
|
||||||
@ -1377,8 +1387,13 @@ class Optimizer:
|
|||||||
def optimize(self) -> ParseResult:
|
def optimize(self) -> ParseResult:
|
||||||
print("\noptimizing parse tree")
|
print("\noptimizing parse tree")
|
||||||
for block in self.parsed.blocks:
|
for block in self.parsed.blocks:
|
||||||
|
self.remove_identity_assigns(block)
|
||||||
self.combine_assignments_into_multi(block)
|
self.combine_assignments_into_multi(block)
|
||||||
self.optimize_multiassigns(block)
|
self.optimize_multiassigns(block)
|
||||||
|
for sub in block.symbols.iter_subroutines(True):
|
||||||
|
self.remove_identity_assigns(sub.sub_block)
|
||||||
|
self.combine_assignments_into_multi(sub.sub_block)
|
||||||
|
self.optimize_multiassigns(sub.sub_block)
|
||||||
return self.parsed
|
return self.parsed
|
||||||
|
|
||||||
def combine_assignments_into_multi(self, block: ParseResult.Block) -> None:
|
def combine_assignments_into_multi(self, block: ParseResult.Block) -> None:
|
||||||
@ -1414,6 +1429,19 @@ class Optimizer:
|
|||||||
# change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any)
|
# change order: first registers, then zp addresses, then non-zp addresses, then the rest (if any)
|
||||||
stmt.leftvalues = list(sorted(lvalues, key=_value_sortkey))
|
stmt.leftvalues = list(sorted(lvalues, key=_value_sortkey))
|
||||||
|
|
||||||
|
def remove_identity_assigns(self, block: ParseResult.Block) -> None:
|
||||||
|
have_removed_stmts = False
|
||||||
|
for index, stmt in enumerate(list(block.statements)):
|
||||||
|
if isinstance(stmt, ParseResult.AssignmentStmt):
|
||||||
|
stmt.remove_identity_lvalues(block.sourceref.file, stmt.lineno)
|
||||||
|
if not stmt.leftvalues:
|
||||||
|
print("{:s}:{:d}: removed identity assignment statement".format(block.sourceref.file, stmt.lineno))
|
||||||
|
have_removed_stmts = True
|
||||||
|
block.statements[index] = None
|
||||||
|
if have_removed_stmts:
|
||||||
|
# remove the Nones
|
||||||
|
block.statements = [s for s in block.statements if s is not None]
|
||||||
|
|
||||||
|
|
||||||
def _value_sortkey(value: ParseResult.Value) -> int:
|
def _value_sortkey(value: ParseResult.Value) -> int:
|
||||||
if isinstance(value, ParseResult.RegisterValue):
|
if isinstance(value, ParseResult.RegisterValue):
|
||||||
|
@ -303,8 +303,12 @@ class SymbolTable:
|
|||||||
def iter_constants(self) -> Iterable[ConstantDef]:
|
def iter_constants(self) -> Iterable[ConstantDef]:
|
||||||
yield from sorted((v for v in self.symbols.values() if isinstance(v, ConstantDef)))
|
yield from sorted((v for v in self.symbols.values() if isinstance(v, ConstantDef)))
|
||||||
|
|
||||||
def iter_subroutines(self) -> Iterable[SubroutineDef]:
|
def iter_subroutines(self, userdefined_only: bool=False) -> Iterable[SubroutineDef]:
|
||||||
yield from sorted((v for v in self.symbols.values() if isinstance(v, SubroutineDef)))
|
if userdefined_only:
|
||||||
|
yield from sorted((sub for sub in self.symbols.values()
|
||||||
|
if isinstance(sub, SubroutineDef) and sub.address is None and sub.sub_block is not None))
|
||||||
|
else:
|
||||||
|
yield from sorted((sub for sub in self.symbols.values() if isinstance(sub, SubroutineDef)))
|
||||||
|
|
||||||
def iter_labels(self) -> Iterable[LabelDef]:
|
def iter_labels(self) -> Iterable[LabelDef]:
|
||||||
yield from sorted((v for v in self.symbols.values() if isinstance(v, LabelDef)))
|
yield from sorted((v for v in self.symbols.values() if isinstance(v, LabelDef)))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user