fix return value clobbering

This commit is contained in:
Irmen de Jong 2017-12-26 01:30:22 +01:00
parent 4a9d3200cd
commit be76d3321b
5 changed files with 124 additions and 49 deletions

View File

@ -77,7 +77,7 @@ def parse_arguments(text: str, sourceref: SourceRef) -> List[Tuple[str, Primitiv
if isinstance(node, ast.Str): if isinstance(node, ast.Str):
return repr(node.s) return repr(node.s)
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): # type: ignore
return '#' + astnode_to_repr(node.right) return '#' + astnode_to_repr(node.right)
else: else:
print("error", ast.dump(node)) print("error", ast.dump(node))

View File

@ -12,7 +12,7 @@ import datetime
import subprocess import subprocess
import contextlib import contextlib
from functools import partial from functools import partial
from typing import TextIO, Set, Union from typing import TextIO, Set, Union, List
from .parse import ProgramFormat, ParseResult, Parser from .parse import ProgramFormat, ParseResult, Parser
from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \ from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \
STRING_DATATYPES, REGISTER_WORDS, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE STRING_DATATYPES, REGISTER_WORDS, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE
@ -438,6 +438,10 @@ class CodeGenerator:
for assign_stmt in stmt.desugared_call_arguments: for assign_stmt in stmt.desugared_call_arguments:
self.generate_assignment(assign_stmt) self.generate_assignment(assign_stmt)
def generate_result_assignments() -> None:
for assign_stmt in stmt.desugared_output_assignments:
self.generate_assignment(assign_stmt)
def params_load_a() -> bool: def params_load_a() -> bool:
for assign_stmt in stmt.desugared_call_arguments: for assign_stmt in stmt.desugared_call_arguments:
for lv in assign_stmt.leftvalues: for lv in assign_stmt.leftvalues:
@ -446,6 +450,16 @@ class CodeGenerator:
return True return True
return False return False
def unclobber_result_registers(registers: Set[str], output_assignments: List[ParseResult.AssignmentStmt]) -> None:
for a in output_assignments:
for lv in a.leftvalues:
if isinstance(lv, ParseResult.RegisterValue):
if len(lv.register) == 1:
registers.remove(lv.register)
else:
for r in lv.register:
registers.remove(r)
if stmt.target.name: if stmt.target.name:
symblock, targetdef = self.cur_block.lookup(stmt.target.name) symblock, targetdef = self.cur_block.lookup(stmt.target.name)
else: else:
@ -459,14 +473,17 @@ class CodeGenerator:
if stmt.is_goto: if stmt.is_goto:
generate_param_assignments() generate_param_assignments()
self.p("\t\tjmp " + targetstr) self.p("\t\tjmp " + targetstr)
# no result assignments because it's a goto
return return
clobbered = set() # type: Set[str] clobbered = set() # type: Set[str]
if targetdef.clobbered_registers: if targetdef.clobbered_registers:
if stmt.preserve_regs: if stmt.preserve_regs:
clobbered = targetdef.clobbered_registers clobbered = targetdef.clobbered_registers
unclobber_result_registers(clobbered, stmt.desugared_output_assignments)
with self.preserving_registers(clobbered, loads_a_within=params_load_a()): with self.preserving_registers(clobbered, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
self.p("\t\tjsr " + targetstr) self.p("\t\tjsr " + targetstr)
generate_result_assignments()
return return
if isinstance(stmt.target, ParseResult.IndirectValue): if isinstance(stmt.target, ParseResult.IndirectValue):
if stmt.target.name: if stmt.target.name:
@ -490,8 +507,10 @@ class CodeGenerator:
self.p("\t\tjmp ({:s})".format(Parser.to_hex(Zeropage.SCRATCH_B1))) self.p("\t\tjmp ({:s})".format(Parser.to_hex(Zeropage.SCRATCH_B1)))
else: else:
self.p("\t\tjmp ({:s})".format(targetstr)) self.p("\t\tjmp ({:s})".format(targetstr))
# no result assignments because it's a goto
else: else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()): with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
if targetstr in REGISTER_WORDS: if targetstr in REGISTER_WORDS:
@ -519,6 +538,7 @@ class CodeGenerator:
self.p("\t\tjmp ++") self.p("\t\tjmp ++")
self.p("+\t\tjmp ({:s})".format(targetstr)) self.p("+\t\tjmp ({:s})".format(targetstr))
self.p("+") self.p("+")
generate_result_assignments()
else: else:
if stmt.target.name: if stmt.target.name:
targetstr = stmt.target.name targetstr = stmt.target.name
@ -532,11 +552,14 @@ class CodeGenerator:
# no need to preserve registers for a goto # no need to preserve registers for a goto
generate_param_assignments() generate_param_assignments()
self.p("\t\tjmp " + targetstr) self.p("\t\tjmp " + targetstr)
# no result assignments because it's a goto
else: else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()): with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
self.p("\t\tjsr " + targetstr) self.p("\t\tjsr " + targetstr)
generate_result_assignments()
def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None: def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None:
def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue: def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue:

View File

@ -9,10 +9,11 @@ License: GNU GPL 3.0, see LICENSE
import math import math
import re import re
import os import os
import sys
import shutil import shutil
import enum import enum
from collections import defaultdict from collections import defaultdict
from typing import Set, List, Tuple, Optional, Any, Dict, Union from typing import Set, List, Tuple, Optional, Any, Dict, Union, Sequence
from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\ from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\
parse_expr_as_string, parse_arguments parse_expr_as_string, parse_arguments
from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \ from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \
@ -409,24 +410,32 @@ class ParseResult:
class CallStmt(_AstNode): class CallStmt(_AstNode):
def __init__(self, lineno: int, target: Optional['ParseResult.Value']=None, *, def __init__(self, lineno: int, target: Optional['ParseResult.Value']=None, *,
address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None, address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None,
is_goto: bool=False, preserve_regs: bool=True) -> None: outputs: List[Tuple[str, 'ParseResult.Value']]=None, is_goto: bool=False, preserve_regs: bool=True) -> None:
self.lineno = lineno self.lineno = lineno
self.target = target self.target = target
self.address = address self.address = address
self.arguments = arguments self.arguments = arguments
self.outputvars = outputs
self.is_goto = is_goto self.is_goto = is_goto
self.preserve_regs = preserve_regs self.preserve_regs = preserve_regs
self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt] self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt]
self.desugared_output_assignments = [] # type: List[ParseResult.AssignmentStmt]
def desugar_call_arguments(self, parser: 'Parser') -> None: def desugar_call_arguments_and_outputs(self, parser: 'Parser') -> None:
if self.arguments: self.desugared_call_arguments.clear()
self.desugared_call_arguments.clear() self.desugared_output_assignments.clear()
for name, value in self.arguments: for name, value in self.arguments or []:
assert name is not None, "all call arguments should have a name or be matched on a named parameter" assert name is not None, "all call arguments should have a name or be matched on a named parameter"
assignment = parser.parse_assignment("{:s}={:s}".format(name, value)) assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
if not assignment.is_identity(): if not assignment.is_identity():
assignment.lineno = self.lineno assignment.lineno = self.lineno
self.desugared_call_arguments.append(assignment) self.desugared_call_arguments.append(assignment)
for register, value in self.outputvars or []:
rvalue = parser.parse_expression(register)
assignment = ParseResult.AssignmentStmt([value], rvalue, self.lineno)
# note: we need the identity assignment here or the output register handling generates buggy code
assignment.lineno = self.lineno
self.desugared_output_assignments.append(assignment)
class InlineAsm(_AstNode): class InlineAsm(_AstNode):
def __init__(self, lineno: int, asmlines: List[str]) -> None: def __init__(self, lineno: int, asmlines: List[str]) -> None:
@ -502,19 +511,27 @@ class Parser:
try: try:
return self.parse_file() return self.parse_file()
except ParseError as x: except ParseError as x:
print() if sys.stderr.isatty():
print("\x1b[1m", file=sys.stderr)
print("", file=sys.stderr)
if x.sourcetext: if x.sourcetext:
print("\tsource text: '{:s}'".format(x.sourcetext)) print("\tsource text: '{:s}'".format(x.sourcetext), file=sys.stderr)
if x.sourceref.column: if x.sourceref.column:
print("\t" + ' '*x.sourceref.column + ' ^') print("\t" + ' '*x.sourceref.column + ' ^', file=sys.stderr)
if self.parsing_import: if self.parsing_import:
print("Error (in imported file):", str(x)) print("Error (in imported file):", str(x), file=sys.stderr)
else: else:
print("Error:", str(x)) print("Error:", str(x), file=sys.stderr)
if sys.stderr.isatty():
print("\x1b[0m", file=sys.stderr)
raise # XXX temporary solution to get stack trace info in the event of parse errors raise # XXX temporary solution to get stack trace info in the event of parse errors
except Exception as x: except Exception as x:
print("\nERROR: internal parser error: ", x) if sys.stderr.isatty():
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line) print("\x1b[1m", file=sys.stderr)
print("\nERROR: internal parser error: ", x, file=sys.stderr)
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line, file=sys.stderr)
if sys.stderr.isatty():
print("\x1b[0m", file=sys.stderr)
raise # XXX temporary solution to get stack trace info in the event of parse errors raise # XXX temporary solution to get stack trace info in the event of parse errors
def parse_file(self) -> ParseResult: def parse_file(self) -> ParseResult:
@ -524,7 +541,10 @@ class Parser:
return self.result return self.result
def print_warning(self, text: str) -> None: def print_warning(self, text: str) -> None:
print(text) if sys.stdout.isatty():
print("\x1b[1m" + text + "\x1b[0m")
else:
print(text)
def _parse_comments(self) -> None: def _parse_comments(self) -> None:
while True: while True:
@ -604,22 +624,26 @@ class Parser:
if isinstance(stmt, ParseResult.CallStmt): if isinstance(stmt, ParseResult.CallStmt):
self.sourceref.line = stmt.lineno self.sourceref.line = stmt.lineno
self.sourceref.column = 0 self.sourceref.column = 0
stmt.desugar_call_arguments(self) stmt.desugar_call_arguments_and_outputs(self)
# create parameter loads for calls, in subroutine blocks # create parameter loads for calls, in subroutine blocks
for sub in block.symbols.iter_subroutines(True): for sub in block.symbols.iter_subroutines(True):
for stmt in sub.sub_block.statements: for stmt in sub.sub_block.statements:
if isinstance(stmt, ParseResult.CallStmt): if isinstance(stmt, ParseResult.CallStmt):
self.sourceref.line = stmt.lineno self.sourceref.line = stmt.lineno
self.sourceref.column = 0 self.sourceref.column = 0
stmt.desugar_call_arguments(self) stmt.desugar_call_arguments_and_outputs(self)
block.flatten_statement_list() block.flatten_statement_list()
# desugar immediate string value assignments # desugar immediate string value assignments
for index, stmt in enumerate(list(block.statements)): for index, stmt in enumerate(list(block.statements)):
if isinstance(stmt, ParseResult.CallStmt): if isinstance(stmt, ParseResult.CallStmt):
for stmt in stmt.desugared_call_arguments: for s in stmt.desugared_call_arguments:
self.sourceref.line = stmt.lineno self.sourceref.line = s.lineno
self.sourceref.column = 0 self.sourceref.column = 0
stmt.desugar_immediate_string(self) s.desugar_immediate_string(self)
for s in stmt.desugared_output_assignments:
self.sourceref.line = s.lineno
self.sourceref.column = 0
s.desugar_immediate_string(self)
if isinstance(stmt, ParseResult.AssignmentStmt): if isinstance(stmt, ParseResult.AssignmentStmt):
self.sourceref.line = stmt.lineno self.sourceref.line = stmt.lineno
self.sourceref.column = 0 self.sourceref.column = 0
@ -938,7 +962,7 @@ class Parser:
all_paramnames = [p[0] for p in parameters if p[0]] all_paramnames = [p[0] for p in parameters if p[0]]
if len(all_paramnames) != len(set(all_paramnames)): if len(all_paramnames) != len(set(all_paramnames)):
raise self.PError("duplicates in parameter names") raise self.PError("duplicates in parameter names")
results = {match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)} results = [match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)]
subroutine_block = None subroutine_block = None
if code_decl: if code_decl:
address = None address = None
@ -1016,17 +1040,31 @@ class Parser:
return varname, datatype, length, matrix_dimensions, valuetext return varname, datatype, length, matrix_dimensions, valuetext
def parse_statement(self, line: str) -> ParseResult._AstNode: def parse_statement(self, line: str) -> ParseResult._AstNode:
match = re.match(r"(?P<outputs>.*\s*=)\s*(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
if match:
# subroutine call (not a goto) with output param assignment
preserve = not bool(match.group("fcall"))
subname = match.group("subname")
arguments = match.group("arguments")
outputs = match.group("outputs")
if outputs.strip() == "=":
raise self.PError("missing assignment target variables")
outputs = outputs.rstrip("=")
if arguments or match.group(4):
return self.parse_call_or_goto(subname, arguments, outputs, preserve, False)
# apparently it is not a call (no arguments), fall through
match = re.match(r"(?P<goto>goto\s+)?(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line) match = re.match(r"(?P<goto>goto\s+)?(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
if match: if match:
# subroutine or goto call # subroutine or goto call, without output param assignment
is_goto = bool(match.group("goto")) is_goto = bool(match.group("goto"))
preserve = not bool(match.group("fcall")) preserve = not bool(match.group("fcall"))
subname = match.group("subname") subname = match.group("subname")
arguments = match.group("arguments") arguments = match.group("arguments")
if is_goto: if is_goto:
return self.parse_call_or_goto(subname, arguments, preserve, True) return self.parse_call_or_goto(subname, arguments, None, preserve, True)
elif arguments or match.group(4): elif arguments or match.group(4):
return self.parse_call_or_goto(subname, arguments, preserve, False) return self.parse_call_or_goto(subname, arguments, None, preserve, False)
# apparently it is not a call (no arguments), fall through
if line == "return" or line.startswith(("return ", "return\t")): if line == "return" or line.startswith(("return ", "return\t")):
return self.parse_return(line) return self.parse_return(line)
elif line.endswith(("++", "--")): elif line.endswith(("++", "--")):
@ -1042,9 +1080,12 @@ class Parser:
return self.parse_assignment(line) return self.parse_assignment(line)
raise self.PError("invalid statement") raise self.PError("invalid statement")
def parse_call_or_goto(self, targetstr: str, argumentstr: str, preserve_regs=True, is_goto=False) -> ParseResult.CallStmt: def parse_call_or_goto(self, targetstr: str, argumentstr: str, outputstr: str,
preserve_regs=True, is_goto=False) -> ParseResult.CallStmt:
argumentstr = argumentstr.strip() if argumentstr else "" argumentstr = argumentstr.strip() if argumentstr else ""
outputstr = outputstr.strip() if outputstr else ""
arguments = None arguments = None
outputvars = None
if argumentstr: if argumentstr:
arguments = parse_arguments(argumentstr, self.sourceref) arguments = parse_arguments(argumentstr, self.sourceref)
target = None # type: ParseResult.Value target = None # type: ParseResult.Value
@ -1080,19 +1121,35 @@ class Parser:
argname = preg argname = preg
args_with_pnames.append((argname, value)) args_with_pnames.append((argname, value))
arguments = args_with_pnames arguments = args_with_pnames
self.result.sub_used_by(symbol, self.sourceref) # verify output parameters
if symbol.return_registers:
if outputstr:
outputs = [r.strip() for r in outputstr.split(",")]
if len(outputs) != len(symbol.return_registers):
raise self.PError("invalid number of output parameters consumed ({:d}, expected {:d})"
.format(len(outputs), len(symbol.return_registers)))
outputvars = list(zip(symbol.return_registers, (self.parse_expression(out) for out in outputs)))
else:
self.print_warning("warning: {}: return values discarded".format(self.sourceref))
else:
if outputstr:
raise self.PError("this subroutine doesn't have output parameters")
self.result.sub_used_by(symbol, self.sourceref) # sub usage tracking
else: else:
if outputstr:
raise self.PError("call cannot use output parameter assignment here, a subroutine is required for that")
if arguments: if arguments:
raise self.PError("call cannot take any arguments here, use a subroutine for that") raise self.PError("call cannot take any arguments here, a subroutine is required for that")
if arguments: # verify that all arguments have gotten a name
# verify that all arguments have gotten a name if any(not a[0] for a in arguments or []):
if any(not a[0] for a in arguments): raise self.PError("all call arguments should have a name or be matched on a named parameter")
raise self.PError("all call arguments should have a name or be matched on a named parameter")
if isinstance(target, (type(None), ParseResult.Value)): if isinstance(target, (type(None), ParseResult.Value)):
if is_goto: if is_goto:
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, is_goto=True) return ParseResult.CallStmt(self.sourceref.line, target, address=address,
arguments=arguments, outputs=outputvars, is_goto=True)
else: else:
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, preserve_regs=preserve_regs) return ParseResult.CallStmt(self.sourceref.line, target, address=address,
arguments=arguments, outputs=outputvars, preserve_regs=preserve_regs)
else: else:
raise TypeError("target should be a Value", target) raise TypeError("target should be a Value", target)
@ -1216,7 +1273,6 @@ class Parser:
def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value: def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value:
# parse an expression into whatever it is (primitive value, register, memory, register, etc) # parse an expression into whatever it is (primitive value, register, memory, register, etc)
# @todo only numeric expressions supported for now
text = text.strip() text = text.strip()
if not text: if not text:
raise self.PError("value expected") raise self.PError("value expected")

View File

@ -167,29 +167,26 @@ class ConstantDef(SymbolDefinition):
class SubroutineDef(SymbolDefinition): class SubroutineDef(SymbolDefinition):
def __init__(self, blockname: str, name: str, sourceref: SourceRef, def __init__(self, blockname: str, name: str, sourceref: SourceRef,
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str], parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
address: Optional[int]=None, sub_block: Any=None) -> None: address: Optional[int]=None, sub_block: Any=None) -> None:
super().__init__(blockname, name, sourceref, False) super().__init__(blockname, name, sourceref, False)
self.address = address self.address = address
self.sub_block = sub_block # this is a ParseResult.Block self.sub_block = sub_block # this is a ParseResult.Block
self.parameters = parameters self.parameters = parameters
self.input_registers = set() # type: Set[str]
self.return_registers = set() # type: Set[str]
self.clobbered_registers = set() # type: Set[str] self.clobbered_registers = set() # type: Set[str]
self.return_registers = [] # type: List[str] # ordered!
for _, param in parameters: for _, param in parameters:
if param in REGISTER_BYTES: if param in REGISTER_BYTES:
self.input_registers.add(param)
self.clobbered_registers.add(param) self.clobbered_registers.add(param)
elif param in REGISTER_WORDS: elif param in REGISTER_WORDS:
self.input_registers.add(param[0])
self.input_registers.add(param[1])
self.clobbered_registers.add(param[0]) self.clobbered_registers.add(param[0])
self.clobbered_registers.add(param[1]) self.clobbered_registers.add(param[1])
else: else:
raise SymbolError("invalid parameter spec: " + param) raise SymbolError("invalid parameter spec: " + param)
for register in returnvalues: for register in returnvalues:
if register in REGISTER_SYMBOLS_RETURNVALUES: if register in REGISTER_SYMBOLS_RETURNVALUES:
self.return_registers.add(register) self.clobbered_registers.add(register)
self.return_registers.append(register)
elif register[-1] == "?": elif register[-1] == "?":
for r in register[:-1]: for r in register[:-1]:
if r not in REGISTER_SYMBOLS_RETURNVALUES: if r not in REGISTER_SYMBOLS_RETURNVALUES:
@ -387,7 +384,7 @@ class SymbolTable:
self.eval_dict = None self.eval_dict = None
def define_sub(self, name: str, sourceref: SourceRef, def define_sub(self, name: str, sourceref: SourceRef,
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str], parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
address: Optional[int], sub_block: Any) -> None: address: Optional[int], sub_block: Any) -> None:
self.check_identifier_valid(name, sourceref) self.check_identifier_valid(name, sourceref)
self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block) self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block)

View File

@ -3,7 +3,6 @@ output prg,sys
import "c64lib" import "c64lib"
~ main { ~ main {
var .text name = "?"*80 var .text name = "?"*80
start start