fix return value clobbering

This commit is contained in:
Irmen de Jong 2017-12-26 01:30:22 +01:00
parent 4a9d3200cd
commit be76d3321b
5 changed files with 124 additions and 49 deletions

View File

@ -77,7 +77,7 @@ def parse_arguments(text: str, sourceref: SourceRef) -> List[Tuple[str, Primitiv
if isinstance(node, ast.Str):
return repr(node.s)
if isinstance(node, ast.BinOp):
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult):
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): # type: ignore
return '#' + astnode_to_repr(node.right)
else:
print("error", ast.dump(node))

View File

@ -12,7 +12,7 @@ import datetime
import subprocess
import contextlib
from functools import partial
from typing import TextIO, Set, Union
from typing import TextIO, Set, Union, List
from .parse import ProgramFormat, ParseResult, Parser
from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \
STRING_DATATYPES, REGISTER_WORDS, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE
@ -438,6 +438,10 @@ class CodeGenerator:
for assign_stmt in stmt.desugared_call_arguments:
self.generate_assignment(assign_stmt)
def generate_result_assignments() -> None:
for assign_stmt in stmt.desugared_output_assignments:
self.generate_assignment(assign_stmt)
def params_load_a() -> bool:
for assign_stmt in stmt.desugared_call_arguments:
for lv in assign_stmt.leftvalues:
@ -446,6 +450,16 @@ class CodeGenerator:
return True
return False
def unclobber_result_registers(registers: Set[str], output_assignments: List[ParseResult.AssignmentStmt]) -> None:
for a in output_assignments:
for lv in a.leftvalues:
if isinstance(lv, ParseResult.RegisterValue):
if len(lv.register) == 1:
registers.remove(lv.register)
else:
for r in lv.register:
registers.remove(r)
if stmt.target.name:
symblock, targetdef = self.cur_block.lookup(stmt.target.name)
else:
@ -459,14 +473,17 @@ class CodeGenerator:
if stmt.is_goto:
generate_param_assignments()
self.p("\t\tjmp " + targetstr)
# no result assignments because it's a goto
return
clobbered = set() # type: Set[str]
if targetdef.clobbered_registers:
if stmt.preserve_regs:
clobbered = targetdef.clobbered_registers
unclobber_result_registers(clobbered, stmt.desugared_output_assignments)
with self.preserving_registers(clobbered, loads_a_within=params_load_a()):
generate_param_assignments()
self.p("\t\tjsr " + targetstr)
generate_result_assignments()
return
if isinstance(stmt.target, ParseResult.IndirectValue):
if stmt.target.name:
@ -490,8 +507,10 @@ class CodeGenerator:
self.p("\t\tjmp ({:s})".format(Parser.to_hex(Zeropage.SCRATCH_B1)))
else:
self.p("\t\tjmp ({:s})".format(targetstr))
# no result assignments because it's a goto
else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments()
if targetstr in REGISTER_WORDS:
@ -519,6 +538,7 @@ class CodeGenerator:
self.p("\t\tjmp ++")
self.p("+\t\tjmp ({:s})".format(targetstr))
self.p("+")
generate_result_assignments()
else:
if stmt.target.name:
targetstr = stmt.target.name
@ -532,11 +552,14 @@ class CodeGenerator:
# no need to preserve registers for a goto
generate_param_assignments()
self.p("\t\tjmp " + targetstr)
# no result assignments because it's a goto
else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments()
self.p("\t\tjsr " + targetstr)
generate_result_assignments()
def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None:
def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue:

View File

@ -9,10 +9,11 @@ License: GNU GPL 3.0, see LICENSE
import math
import re
import os
import sys
import shutil
import enum
from collections import defaultdict
from typing import Set, List, Tuple, Optional, Any, Dict, Union
from typing import Set, List, Tuple, Optional, Any, Dict, Union, Sequence
from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\
parse_expr_as_string, parse_arguments
from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \
@ -409,24 +410,32 @@ class ParseResult:
class CallStmt(_AstNode):
def __init__(self, lineno: int, target: Optional['ParseResult.Value']=None, *,
address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None,
is_goto: bool=False, preserve_regs: bool=True) -> None:
outputs: List[Tuple[str, 'ParseResult.Value']]=None, is_goto: bool=False, preserve_regs: bool=True) -> None:
self.lineno = lineno
self.target = target
self.address = address
self.arguments = arguments
self.outputvars = outputs
self.is_goto = is_goto
self.preserve_regs = preserve_regs
self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt]
self.desugared_output_assignments = [] # type: List[ParseResult.AssignmentStmt]
def desugar_call_arguments(self, parser: 'Parser') -> None:
if self.arguments:
self.desugared_call_arguments.clear()
for name, value in self.arguments:
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
if not assignment.is_identity():
assignment.lineno = self.lineno
self.desugared_call_arguments.append(assignment)
def desugar_call_arguments_and_outputs(self, parser: 'Parser') -> None:
self.desugared_call_arguments.clear()
self.desugared_output_assignments.clear()
for name, value in self.arguments or []:
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
if not assignment.is_identity():
assignment.lineno = self.lineno
self.desugared_call_arguments.append(assignment)
for register, value in self.outputvars or []:
rvalue = parser.parse_expression(register)
assignment = ParseResult.AssignmentStmt([value], rvalue, self.lineno)
# note: we need the identity assignment here or the output register handling generates buggy code
assignment.lineno = self.lineno
self.desugared_output_assignments.append(assignment)
class InlineAsm(_AstNode):
def __init__(self, lineno: int, asmlines: List[str]) -> None:
@ -502,19 +511,27 @@ class Parser:
try:
return self.parse_file()
except ParseError as x:
print()
if sys.stderr.isatty():
print("\x1b[1m", file=sys.stderr)
print("", file=sys.stderr)
if x.sourcetext:
print("\tsource text: '{:s}'".format(x.sourcetext))
print("\tsource text: '{:s}'".format(x.sourcetext), file=sys.stderr)
if x.sourceref.column:
print("\t" + ' '*x.sourceref.column + ' ^')
print("\t" + ' '*x.sourceref.column + ' ^', file=sys.stderr)
if self.parsing_import:
print("Error (in imported file):", str(x))
print("Error (in imported file):", str(x), file=sys.stderr)
else:
print("Error:", str(x))
print("Error:", str(x), file=sys.stderr)
if sys.stderr.isatty():
print("\x1b[0m", file=sys.stderr)
raise # XXX temporary solution to get stack trace info in the event of parse errors
except Exception as x:
print("\nERROR: internal parser error: ", x)
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line)
if sys.stderr.isatty():
print("\x1b[1m", file=sys.stderr)
print("\nERROR: internal parser error: ", x, file=sys.stderr)
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line, file=sys.stderr)
if sys.stderr.isatty():
print("\x1b[0m", file=sys.stderr)
raise # XXX temporary solution to get stack trace info in the event of parse errors
def parse_file(self) -> ParseResult:
@ -524,7 +541,10 @@ class Parser:
return self.result
def print_warning(self, text: str) -> None:
print(text)
if sys.stdout.isatty():
print("\x1b[1m" + text + "\x1b[0m")
else:
print(text)
def _parse_comments(self) -> None:
while True:
@ -604,22 +624,26 @@ class Parser:
if isinstance(stmt, ParseResult.CallStmt):
self.sourceref.line = stmt.lineno
self.sourceref.column = 0
stmt.desugar_call_arguments(self)
stmt.desugar_call_arguments_and_outputs(self)
# create parameter loads for calls, in subroutine blocks
for sub in block.symbols.iter_subroutines(True):
for stmt in sub.sub_block.statements:
if isinstance(stmt, ParseResult.CallStmt):
self.sourceref.line = stmt.lineno
self.sourceref.column = 0
stmt.desugar_call_arguments(self)
stmt.desugar_call_arguments_and_outputs(self)
block.flatten_statement_list()
# desugar immediate string value assignments
for index, stmt in enumerate(list(block.statements)):
if isinstance(stmt, ParseResult.CallStmt):
for stmt in stmt.desugared_call_arguments:
self.sourceref.line = stmt.lineno
for s in stmt.desugared_call_arguments:
self.sourceref.line = s.lineno
self.sourceref.column = 0
stmt.desugar_immediate_string(self)
s.desugar_immediate_string(self)
for s in stmt.desugared_output_assignments:
self.sourceref.line = s.lineno
self.sourceref.column = 0
s.desugar_immediate_string(self)
if isinstance(stmt, ParseResult.AssignmentStmt):
self.sourceref.line = stmt.lineno
self.sourceref.column = 0
@ -938,7 +962,7 @@ class Parser:
all_paramnames = [p[0] for p in parameters if p[0]]
if len(all_paramnames) != len(set(all_paramnames)):
raise self.PError("duplicates in parameter names")
results = {match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)}
results = [match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)]
subroutine_block = None
if code_decl:
address = None
@ -1016,17 +1040,31 @@ class Parser:
return varname, datatype, length, matrix_dimensions, valuetext
def parse_statement(self, line: str) -> ParseResult._AstNode:
match = re.match(r"(?P<outputs>.*\s*=)\s*(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
if match:
# subroutine call (not a goto) with output param assignment
preserve = not bool(match.group("fcall"))
subname = match.group("subname")
arguments = match.group("arguments")
outputs = match.group("outputs")
if outputs.strip() == "=":
raise self.PError("missing assignment target variables")
outputs = outputs.rstrip("=")
if arguments or match.group(4):
return self.parse_call_or_goto(subname, arguments, outputs, preserve, False)
# apparently it is not a call (no arguments), fall through
match = re.match(r"(?P<goto>goto\s+)?(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
if match:
# subroutine or goto call
# subroutine or goto call, without output param assignment
is_goto = bool(match.group("goto"))
preserve = not bool(match.group("fcall"))
subname = match.group("subname")
arguments = match.group("arguments")
if is_goto:
return self.parse_call_or_goto(subname, arguments, preserve, True)
return self.parse_call_or_goto(subname, arguments, None, preserve, True)
elif arguments or match.group(4):
return self.parse_call_or_goto(subname, arguments, preserve, False)
return self.parse_call_or_goto(subname, arguments, None, preserve, False)
# apparently it is not a call (no arguments), fall through
if line == "return" or line.startswith(("return ", "return\t")):
return self.parse_return(line)
elif line.endswith(("++", "--")):
@ -1042,9 +1080,12 @@ class Parser:
return self.parse_assignment(line)
raise self.PError("invalid statement")
def parse_call_or_goto(self, targetstr: str, argumentstr: str, preserve_regs=True, is_goto=False) -> ParseResult.CallStmt:
def parse_call_or_goto(self, targetstr: str, argumentstr: str, outputstr: str,
preserve_regs=True, is_goto=False) -> ParseResult.CallStmt:
argumentstr = argumentstr.strip() if argumentstr else ""
outputstr = outputstr.strip() if outputstr else ""
arguments = None
outputvars = None
if argumentstr:
arguments = parse_arguments(argumentstr, self.sourceref)
target = None # type: ParseResult.Value
@ -1080,19 +1121,35 @@ class Parser:
argname = preg
args_with_pnames.append((argname, value))
arguments = args_with_pnames
self.result.sub_used_by(symbol, self.sourceref)
# verify output parameters
if symbol.return_registers:
if outputstr:
outputs = [r.strip() for r in outputstr.split(",")]
if len(outputs) != len(symbol.return_registers):
raise self.PError("invalid number of output parameters consumed ({:d}, expected {:d})"
.format(len(outputs), len(symbol.return_registers)))
outputvars = list(zip(symbol.return_registers, (self.parse_expression(out) for out in outputs)))
else:
self.print_warning("warning: {}: return values discarded".format(self.sourceref))
else:
if outputstr:
raise self.PError("this subroutine doesn't have output parameters")
self.result.sub_used_by(symbol, self.sourceref) # sub usage tracking
else:
if outputstr:
raise self.PError("call cannot use output parameter assignment here, a subroutine is required for that")
if arguments:
raise self.PError("call cannot take any arguments here, use a subroutine for that")
if arguments:
# verify that all arguments have gotten a name
if any(not a[0] for a in arguments):
raise self.PError("all call arguments should have a name or be matched on a named parameter")
raise self.PError("call cannot take any arguments here, a subroutine is required for that")
# verify that all arguments have gotten a name
if any(not a[0] for a in arguments or []):
raise self.PError("all call arguments should have a name or be matched on a named parameter")
if isinstance(target, (type(None), ParseResult.Value)):
if is_goto:
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, is_goto=True)
return ParseResult.CallStmt(self.sourceref.line, target, address=address,
arguments=arguments, outputs=outputvars, is_goto=True)
else:
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, preserve_regs=preserve_regs)
return ParseResult.CallStmt(self.sourceref.line, target, address=address,
arguments=arguments, outputs=outputvars, preserve_regs=preserve_regs)
else:
raise TypeError("target should be a Value", target)
@ -1216,7 +1273,6 @@ class Parser:
def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value:
# parse an expression into whatever it is (primitive value, register, memory, register, etc)
# @todo only numeric expressions supported for now
text = text.strip()
if not text:
raise self.PError("value expected")

View File

@ -167,29 +167,26 @@ class ConstantDef(SymbolDefinition):
class SubroutineDef(SymbolDefinition):
def __init__(self, blockname: str, name: str, sourceref: SourceRef,
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str],
parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
address: Optional[int]=None, sub_block: Any=None) -> None:
super().__init__(blockname, name, sourceref, False)
self.address = address
self.sub_block = sub_block # this is a ParseResult.Block
self.parameters = parameters
self.input_registers = set() # type: Set[str]
self.return_registers = set() # type: Set[str]
self.clobbered_registers = set() # type: Set[str]
self.return_registers = [] # type: List[str] # ordered!
for _, param in parameters:
if param in REGISTER_BYTES:
self.input_registers.add(param)
self.clobbered_registers.add(param)
elif param in REGISTER_WORDS:
self.input_registers.add(param[0])
self.input_registers.add(param[1])
self.clobbered_registers.add(param[0])
self.clobbered_registers.add(param[1])
else:
raise SymbolError("invalid parameter spec: " + param)
for register in returnvalues:
if register in REGISTER_SYMBOLS_RETURNVALUES:
self.return_registers.add(register)
self.clobbered_registers.add(register)
self.return_registers.append(register)
elif register[-1] == "?":
for r in register[:-1]:
if r not in REGISTER_SYMBOLS_RETURNVALUES:
@ -387,7 +384,7 @@ class SymbolTable:
self.eval_dict = None
def define_sub(self, name: str, sourceref: SourceRef,
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str],
parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
address: Optional[int], sub_block: Any) -> None:
self.check_identifier_valid(name, sourceref)
self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block)

View File

@ -3,7 +3,6 @@ output prg,sys
import "c64lib"
~ main {
var .text name = "?"*80
start