From be76d3321b4347b3eed1f7ffa08a0f1ede99ba71 Mon Sep 17 00:00:00 2001 From: Irmen de Jong Date: Tue, 26 Dec 2017 01:30:22 +0100 Subject: [PATCH] fix return value clobbering --- il65/astparse.py | 2 +- il65/codegen.py | 25 +++++++- il65/parse.py | 132 ++++++++++++++++++++++++++++++------------- il65/symbols.py | 13 ++--- testsource/input.ill | 1 - 5 files changed, 124 insertions(+), 49 deletions(-) diff --git a/il65/astparse.py b/il65/astparse.py index 8a811647d..20c7c1fb9 100644 --- a/il65/astparse.py +++ b/il65/astparse.py @@ -77,7 +77,7 @@ def parse_arguments(text: str, sourceref: SourceRef) -> List[Tuple[str, Primitiv if isinstance(node, ast.Str): return repr(node.s) if isinstance(node, ast.BinOp): - if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): + if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): # type: ignore return '#' + astnode_to_repr(node.right) else: print("error", ast.dump(node)) diff --git a/il65/codegen.py b/il65/codegen.py index 6a58dd14b..17389f51c 100644 --- a/il65/codegen.py +++ b/il65/codegen.py @@ -12,7 +12,7 @@ import datetime import subprocess import contextlib from functools import partial -from typing import TextIO, Set, Union +from typing import TextIO, Set, Union, List from .parse import ProgramFormat, ParseResult, Parser from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \ STRING_DATATYPES, REGISTER_WORDS, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE @@ -438,6 +438,10 @@ class CodeGenerator: for assign_stmt in stmt.desugared_call_arguments: self.generate_assignment(assign_stmt) + def generate_result_assignments() -> None: + for assign_stmt in stmt.desugared_output_assignments: + self.generate_assignment(assign_stmt) + def params_load_a() -> bool: for assign_stmt in stmt.desugared_call_arguments: for lv in assign_stmt.leftvalues: @@ -446,6 +450,16 @@ class CodeGenerator: return True return False + def unclobber_result_registers(registers: Set[str], output_assignments: List[ParseResult.AssignmentStmt]) -> None: + for a in output_assignments: + for lv in a.leftvalues: + if isinstance(lv, ParseResult.RegisterValue): + if len(lv.register) == 1: + registers.remove(lv.register) + else: + for r in lv.register: + registers.remove(r) + if stmt.target.name: symblock, targetdef = self.cur_block.lookup(stmt.target.name) else: @@ -459,14 +473,17 @@ class CodeGenerator: if stmt.is_goto: generate_param_assignments() self.p("\t\tjmp " + targetstr) + # no result assignments because it's a goto return clobbered = set() # type: Set[str] if targetdef.clobbered_registers: if stmt.preserve_regs: clobbered = targetdef.clobbered_registers + unclobber_result_registers(clobbered, stmt.desugared_output_assignments) with self.preserving_registers(clobbered, loads_a_within=params_load_a()): generate_param_assignments() self.p("\t\tjsr " + targetstr) + generate_result_assignments() return if isinstance(stmt.target, ParseResult.IndirectValue): if stmt.target.name: @@ -490,8 +507,10 @@ class CodeGenerator: self.p("\t\tjmp ({:s})".format(Parser.to_hex(Zeropage.SCRATCH_B1))) else: self.p("\t\tjmp ({:s})".format(targetstr)) + # no result assignments because it's a goto else: preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() + unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments) with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()): generate_param_assignments() if targetstr in REGISTER_WORDS: @@ -519,6 +538,7 @@ class CodeGenerator: self.p("\t\tjmp ++") self.p("+\t\tjmp ({:s})".format(targetstr)) self.p("+") + generate_result_assignments() else: if stmt.target.name: targetstr = stmt.target.name @@ -532,11 +552,14 @@ class CodeGenerator: # no need to preserve registers for a goto generate_param_assignments() self.p("\t\tjmp " + targetstr) + # no result assignments because it's a goto else: preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() + unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments) with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()): generate_param_assignments() self.p("\t\tjsr " + targetstr) + generate_result_assignments() def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None: def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue: diff --git a/il65/parse.py b/il65/parse.py index 5e48003e5..38d579958 100644 --- a/il65/parse.py +++ b/il65/parse.py @@ -9,10 +9,11 @@ License: GNU GPL 3.0, see LICENSE import math import re import os +import sys import shutil import enum from collections import defaultdict -from typing import Set, List, Tuple, Optional, Any, Dict, Union +from typing import Set, List, Tuple, Optional, Any, Dict, Union, Sequence from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\ parse_expr_as_string, parse_arguments from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \ @@ -409,24 +410,32 @@ class ParseResult: class CallStmt(_AstNode): def __init__(self, lineno: int, target: Optional['ParseResult.Value']=None, *, address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None, - is_goto: bool=False, preserve_regs: bool=True) -> None: + outputs: List[Tuple[str, 'ParseResult.Value']]=None, is_goto: bool=False, preserve_regs: bool=True) -> None: self.lineno = lineno self.target = target self.address = address self.arguments = arguments + self.outputvars = outputs self.is_goto = is_goto self.preserve_regs = preserve_regs self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt] + self.desugared_output_assignments = [] # type: List[ParseResult.AssignmentStmt] - def desugar_call_arguments(self, parser: 'Parser') -> None: - if self.arguments: - self.desugared_call_arguments.clear() - for name, value in self.arguments: - assert name is not None, "all call arguments should have a name or be matched on a named parameter" - assignment = parser.parse_assignment("{:s}={:s}".format(name, value)) - if not assignment.is_identity(): - assignment.lineno = self.lineno - self.desugared_call_arguments.append(assignment) + def desugar_call_arguments_and_outputs(self, parser: 'Parser') -> None: + self.desugared_call_arguments.clear() + self.desugared_output_assignments.clear() + for name, value in self.arguments or []: + assert name is not None, "all call arguments should have a name or be matched on a named parameter" + assignment = parser.parse_assignment("{:s}={:s}".format(name, value)) + if not assignment.is_identity(): + assignment.lineno = self.lineno + self.desugared_call_arguments.append(assignment) + for register, value in self.outputvars or []: + rvalue = parser.parse_expression(register) + assignment = ParseResult.AssignmentStmt([value], rvalue, self.lineno) + # note: we need the identity assignment here or the output register handling generates buggy code + assignment.lineno = self.lineno + self.desugared_output_assignments.append(assignment) class InlineAsm(_AstNode): def __init__(self, lineno: int, asmlines: List[str]) -> None: @@ -502,19 +511,27 @@ class Parser: try: return self.parse_file() except ParseError as x: - print() + if sys.stderr.isatty(): + print("\x1b[1m", file=sys.stderr) + print("", file=sys.stderr) if x.sourcetext: - print("\tsource text: '{:s}'".format(x.sourcetext)) + print("\tsource text: '{:s}'".format(x.sourcetext), file=sys.stderr) if x.sourceref.column: - print("\t" + ' '*x.sourceref.column + ' ^') + print("\t" + ' '*x.sourceref.column + ' ^', file=sys.stderr) if self.parsing_import: - print("Error (in imported file):", str(x)) + print("Error (in imported file):", str(x), file=sys.stderr) else: - print("Error:", str(x)) + print("Error:", str(x), file=sys.stderr) + if sys.stderr.isatty(): + print("\x1b[0m", file=sys.stderr) raise # XXX temporary solution to get stack trace info in the event of parse errors except Exception as x: - print("\nERROR: internal parser error: ", x) - print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line) + if sys.stderr.isatty(): + print("\x1b[1m", file=sys.stderr) + print("\nERROR: internal parser error: ", x, file=sys.stderr) + print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line, file=sys.stderr) + if sys.stderr.isatty(): + print("\x1b[0m", file=sys.stderr) raise # XXX temporary solution to get stack trace info in the event of parse errors def parse_file(self) -> ParseResult: @@ -524,7 +541,10 @@ class Parser: return self.result def print_warning(self, text: str) -> None: - print(text) + if sys.stdout.isatty(): + print("\x1b[1m" + text + "\x1b[0m") + else: + print(text) def _parse_comments(self) -> None: while True: @@ -604,22 +624,26 @@ class Parser: if isinstance(stmt, ParseResult.CallStmt): self.sourceref.line = stmt.lineno self.sourceref.column = 0 - stmt.desugar_call_arguments(self) + stmt.desugar_call_arguments_and_outputs(self) # create parameter loads for calls, in subroutine blocks for sub in block.symbols.iter_subroutines(True): for stmt in sub.sub_block.statements: if isinstance(stmt, ParseResult.CallStmt): self.sourceref.line = stmt.lineno self.sourceref.column = 0 - stmt.desugar_call_arguments(self) + stmt.desugar_call_arguments_and_outputs(self) block.flatten_statement_list() # desugar immediate string value assignments for index, stmt in enumerate(list(block.statements)): if isinstance(stmt, ParseResult.CallStmt): - for stmt in stmt.desugared_call_arguments: - self.sourceref.line = stmt.lineno + for s in stmt.desugared_call_arguments: + self.sourceref.line = s.lineno self.sourceref.column = 0 - stmt.desugar_immediate_string(self) + s.desugar_immediate_string(self) + for s in stmt.desugared_output_assignments: + self.sourceref.line = s.lineno + self.sourceref.column = 0 + s.desugar_immediate_string(self) if isinstance(stmt, ParseResult.AssignmentStmt): self.sourceref.line = stmt.lineno self.sourceref.column = 0 @@ -938,7 +962,7 @@ class Parser: all_paramnames = [p[0] for p in parameters if p[0]] if len(all_paramnames) != len(set(all_paramnames)): raise self.PError("duplicates in parameter names") - results = {match.group("name") for match in re.finditer(r"\s*(?P(?:\w+)\??)\s*(?:,|$)", resultlist)} + results = [match.group("name") for match in re.finditer(r"\s*(?P(?:\w+)\??)\s*(?:,|$)", resultlist)] subroutine_block = None if code_decl: address = None @@ -1016,17 +1040,31 @@ class Parser: return varname, datatype, length, matrix_dimensions, valuetext def parse_statement(self, line: str) -> ParseResult._AstNode: + match = re.match(r"(?P.*\s*=)\s*(?P[\S]+?)\s*(?P[!]?)\s*(\((?P.*)\))?\s*$", line) + if match: + # subroutine call (not a goto) with output param assignment + preserve = not bool(match.group("fcall")) + subname = match.group("subname") + arguments = match.group("arguments") + outputs = match.group("outputs") + if outputs.strip() == "=": + raise self.PError("missing assignment target variables") + outputs = outputs.rstrip("=") + if arguments or match.group(4): + return self.parse_call_or_goto(subname, arguments, outputs, preserve, False) + # apparently it is not a call (no arguments), fall through match = re.match(r"(?Pgoto\s+)?(?P[\S]+?)\s*(?P[!]?)\s*(\((?P.*)\))?\s*$", line) if match: - # subroutine or goto call + # subroutine or goto call, without output param assignment is_goto = bool(match.group("goto")) preserve = not bool(match.group("fcall")) subname = match.group("subname") arguments = match.group("arguments") if is_goto: - return self.parse_call_or_goto(subname, arguments, preserve, True) + return self.parse_call_or_goto(subname, arguments, None, preserve, True) elif arguments or match.group(4): - return self.parse_call_or_goto(subname, arguments, preserve, False) + return self.parse_call_or_goto(subname, arguments, None, preserve, False) + # apparently it is not a call (no arguments), fall through if line == "return" or line.startswith(("return ", "return\t")): return self.parse_return(line) elif line.endswith(("++", "--")): @@ -1042,9 +1080,12 @@ class Parser: return self.parse_assignment(line) raise self.PError("invalid statement") - def parse_call_or_goto(self, targetstr: str, argumentstr: str, preserve_regs=True, is_goto=False) -> ParseResult.CallStmt: + def parse_call_or_goto(self, targetstr: str, argumentstr: str, outputstr: str, + preserve_regs=True, is_goto=False) -> ParseResult.CallStmt: argumentstr = argumentstr.strip() if argumentstr else "" + outputstr = outputstr.strip() if outputstr else "" arguments = None + outputvars = None if argumentstr: arguments = parse_arguments(argumentstr, self.sourceref) target = None # type: ParseResult.Value @@ -1080,19 +1121,35 @@ class Parser: argname = preg args_with_pnames.append((argname, value)) arguments = args_with_pnames - self.result.sub_used_by(symbol, self.sourceref) + # verify output parameters + if symbol.return_registers: + if outputstr: + outputs = [r.strip() for r in outputstr.split(",")] + if len(outputs) != len(symbol.return_registers): + raise self.PError("invalid number of output parameters consumed ({:d}, expected {:d})" + .format(len(outputs), len(symbol.return_registers))) + outputvars = list(zip(symbol.return_registers, (self.parse_expression(out) for out in outputs))) + else: + self.print_warning("warning: {}: return values discarded".format(self.sourceref)) + else: + if outputstr: + raise self.PError("this subroutine doesn't have output parameters") + self.result.sub_used_by(symbol, self.sourceref) # sub usage tracking else: + if outputstr: + raise self.PError("call cannot use output parameter assignment here, a subroutine is required for that") if arguments: - raise self.PError("call cannot take any arguments here, use a subroutine for that") - if arguments: - # verify that all arguments have gotten a name - if any(not a[0] for a in arguments): - raise self.PError("all call arguments should have a name or be matched on a named parameter") + raise self.PError("call cannot take any arguments here, a subroutine is required for that") + # verify that all arguments have gotten a name + if any(not a[0] for a in arguments or []): + raise self.PError("all call arguments should have a name or be matched on a named parameter") if isinstance(target, (type(None), ParseResult.Value)): if is_goto: - return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, is_goto=True) + return ParseResult.CallStmt(self.sourceref.line, target, address=address, + arguments=arguments, outputs=outputvars, is_goto=True) else: - return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, preserve_regs=preserve_regs) + return ParseResult.CallStmt(self.sourceref.line, target, address=address, + arguments=arguments, outputs=outputvars, preserve_regs=preserve_regs) else: raise TypeError("target should be a Value", target) @@ -1216,7 +1273,6 @@ class Parser: def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value: # parse an expression into whatever it is (primitive value, register, memory, register, etc) - # @todo only numeric expressions supported for now text = text.strip() if not text: raise self.PError("value expected") diff --git a/il65/symbols.py b/il65/symbols.py index 754169f02..f04b8874f 100644 --- a/il65/symbols.py +++ b/il65/symbols.py @@ -167,29 +167,26 @@ class ConstantDef(SymbolDefinition): class SubroutineDef(SymbolDefinition): def __init__(self, blockname: str, name: str, sourceref: SourceRef, - parameters: Sequence[Tuple[str, str]], returnvalues: Set[str], + parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str], address: Optional[int]=None, sub_block: Any=None) -> None: super().__init__(blockname, name, sourceref, False) self.address = address self.sub_block = sub_block # this is a ParseResult.Block self.parameters = parameters - self.input_registers = set() # type: Set[str] - self.return_registers = set() # type: Set[str] self.clobbered_registers = set() # type: Set[str] + self.return_registers = [] # type: List[str] # ordered! for _, param in parameters: if param in REGISTER_BYTES: - self.input_registers.add(param) self.clobbered_registers.add(param) elif param in REGISTER_WORDS: - self.input_registers.add(param[0]) - self.input_registers.add(param[1]) self.clobbered_registers.add(param[0]) self.clobbered_registers.add(param[1]) else: raise SymbolError("invalid parameter spec: " + param) for register in returnvalues: if register in REGISTER_SYMBOLS_RETURNVALUES: - self.return_registers.add(register) + self.clobbered_registers.add(register) + self.return_registers.append(register) elif register[-1] == "?": for r in register[:-1]: if r not in REGISTER_SYMBOLS_RETURNVALUES: @@ -387,7 +384,7 @@ class SymbolTable: self.eval_dict = None def define_sub(self, name: str, sourceref: SourceRef, - parameters: Sequence[Tuple[str, str]], returnvalues: Set[str], + parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str], address: Optional[int], sub_block: Any) -> None: self.check_identifier_valid(name, sourceref) self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block) diff --git a/testsource/input.ill b/testsource/input.ill index e302dad79..4f3ef13dd 100644 --- a/testsource/input.ill +++ b/testsource/input.ill @@ -3,7 +3,6 @@ output prg,sys import "c64lib" ~ main { - var .text name = "?"*80 start