mirror of
https://github.com/irmen/prog8.git
synced 2025-01-26 19:30:59 +00:00
fix return value clobbering
This commit is contained in:
parent
4a9d3200cd
commit
be76d3321b
@ -77,7 +77,7 @@ def parse_arguments(text: str, sourceref: SourceRef) -> List[Tuple[str, Primitiv
|
||||
if isinstance(node, ast.Str):
|
||||
return repr(node.s)
|
||||
if isinstance(node, ast.BinOp):
|
||||
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult):
|
||||
if node.left.id == "__ptr" and isinstance(node.op, ast.MatMult): # type: ignore
|
||||
return '#' + astnode_to_repr(node.right)
|
||||
else:
|
||||
print("error", ast.dump(node))
|
||||
|
@ -12,7 +12,7 @@ import datetime
|
||||
import subprocess
|
||||
import contextlib
|
||||
from functools import partial
|
||||
from typing import TextIO, Set, Union
|
||||
from typing import TextIO, Set, Union, List
|
||||
from .parse import ProgramFormat, ParseResult, Parser
|
||||
from .symbols import Zeropage, DataType, ConstantDef, VariableDef, SubroutineDef, \
|
||||
STRING_DATATYPES, REGISTER_WORDS, FLOAT_MAX_NEGATIVE, FLOAT_MAX_POSITIVE
|
||||
@ -438,6 +438,10 @@ class CodeGenerator:
|
||||
for assign_stmt in stmt.desugared_call_arguments:
|
||||
self.generate_assignment(assign_stmt)
|
||||
|
||||
def generate_result_assignments() -> None:
|
||||
for assign_stmt in stmt.desugared_output_assignments:
|
||||
self.generate_assignment(assign_stmt)
|
||||
|
||||
def params_load_a() -> bool:
|
||||
for assign_stmt in stmt.desugared_call_arguments:
|
||||
for lv in assign_stmt.leftvalues:
|
||||
@ -446,6 +450,16 @@ class CodeGenerator:
|
||||
return True
|
||||
return False
|
||||
|
||||
def unclobber_result_registers(registers: Set[str], output_assignments: List[ParseResult.AssignmentStmt]) -> None:
|
||||
for a in output_assignments:
|
||||
for lv in a.leftvalues:
|
||||
if isinstance(lv, ParseResult.RegisterValue):
|
||||
if len(lv.register) == 1:
|
||||
registers.remove(lv.register)
|
||||
else:
|
||||
for r in lv.register:
|
||||
registers.remove(r)
|
||||
|
||||
if stmt.target.name:
|
||||
symblock, targetdef = self.cur_block.lookup(stmt.target.name)
|
||||
else:
|
||||
@ -459,14 +473,17 @@ class CodeGenerator:
|
||||
if stmt.is_goto:
|
||||
generate_param_assignments()
|
||||
self.p("\t\tjmp " + targetstr)
|
||||
# no result assignments because it's a goto
|
||||
return
|
||||
clobbered = set() # type: Set[str]
|
||||
if targetdef.clobbered_registers:
|
||||
if stmt.preserve_regs:
|
||||
clobbered = targetdef.clobbered_registers
|
||||
unclobber_result_registers(clobbered, stmt.desugared_output_assignments)
|
||||
with self.preserving_registers(clobbered, loads_a_within=params_load_a()):
|
||||
generate_param_assignments()
|
||||
self.p("\t\tjsr " + targetstr)
|
||||
generate_result_assignments()
|
||||
return
|
||||
if isinstance(stmt.target, ParseResult.IndirectValue):
|
||||
if stmt.target.name:
|
||||
@ -490,8 +507,10 @@ class CodeGenerator:
|
||||
self.p("\t\tjmp ({:s})".format(Parser.to_hex(Zeropage.SCRATCH_B1)))
|
||||
else:
|
||||
self.p("\t\tjmp ({:s})".format(targetstr))
|
||||
# no result assignments because it's a goto
|
||||
else:
|
||||
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
|
||||
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
|
||||
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
|
||||
generate_param_assignments()
|
||||
if targetstr in REGISTER_WORDS:
|
||||
@ -519,6 +538,7 @@ class CodeGenerator:
|
||||
self.p("\t\tjmp ++")
|
||||
self.p("+\t\tjmp ({:s})".format(targetstr))
|
||||
self.p("+")
|
||||
generate_result_assignments()
|
||||
else:
|
||||
if stmt.target.name:
|
||||
targetstr = stmt.target.name
|
||||
@ -532,11 +552,14 @@ class CodeGenerator:
|
||||
# no need to preserve registers for a goto
|
||||
generate_param_assignments()
|
||||
self.p("\t\tjmp " + targetstr)
|
||||
# no result assignments because it's a goto
|
||||
else:
|
||||
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
|
||||
unclobber_result_registers(preserve_regs, stmt.desugared_output_assignments)
|
||||
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
|
||||
generate_param_assignments()
|
||||
self.p("\t\tjsr " + targetstr)
|
||||
generate_result_assignments()
|
||||
|
||||
def generate_assignment(self, stmt: ParseResult.AssignmentStmt) -> None:
|
||||
def unwrap_indirect(iv: ParseResult.IndirectValue) -> ParseResult.MemMappedValue:
|
||||
|
132
il65/parse.py
132
il65/parse.py
@ -9,10 +9,11 @@ License: GNU GPL 3.0, see LICENSE
|
||||
import math
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import enum
|
||||
from collections import defaultdict
|
||||
from typing import Set, List, Tuple, Optional, Any, Dict, Union
|
||||
from typing import Set, List, Tuple, Optional, Any, Dict, Union, Sequence
|
||||
from .astparse import ParseError, parse_expr_as_int, parse_expr_as_number, parse_expr_as_primitive,\
|
||||
parse_expr_as_string, parse_arguments
|
||||
from .symbols import SourceRef, SymbolTable, DataType, SymbolDefinition, SubroutineDef, LabelDef, \
|
||||
@ -409,24 +410,32 @@ class ParseResult:
|
||||
class CallStmt(_AstNode):
|
||||
def __init__(self, lineno: int, target: Optional['ParseResult.Value']=None, *,
|
||||
address: Optional[int]=None, arguments: List[Tuple[str, Any]]=None,
|
||||
is_goto: bool=False, preserve_regs: bool=True) -> None:
|
||||
outputs: List[Tuple[str, 'ParseResult.Value']]=None, is_goto: bool=False, preserve_regs: bool=True) -> None:
|
||||
self.lineno = lineno
|
||||
self.target = target
|
||||
self.address = address
|
||||
self.arguments = arguments
|
||||
self.outputvars = outputs
|
||||
self.is_goto = is_goto
|
||||
self.preserve_regs = preserve_regs
|
||||
self.desugared_call_arguments = [] # type: List[ParseResult.AssignmentStmt]
|
||||
self.desugared_output_assignments = [] # type: List[ParseResult.AssignmentStmt]
|
||||
|
||||
def desugar_call_arguments(self, parser: 'Parser') -> None:
|
||||
if self.arguments:
|
||||
self.desugared_call_arguments.clear()
|
||||
for name, value in self.arguments:
|
||||
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
|
||||
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
|
||||
if not assignment.is_identity():
|
||||
assignment.lineno = self.lineno
|
||||
self.desugared_call_arguments.append(assignment)
|
||||
def desugar_call_arguments_and_outputs(self, parser: 'Parser') -> None:
|
||||
self.desugared_call_arguments.clear()
|
||||
self.desugared_output_assignments.clear()
|
||||
for name, value in self.arguments or []:
|
||||
assert name is not None, "all call arguments should have a name or be matched on a named parameter"
|
||||
assignment = parser.parse_assignment("{:s}={:s}".format(name, value))
|
||||
if not assignment.is_identity():
|
||||
assignment.lineno = self.lineno
|
||||
self.desugared_call_arguments.append(assignment)
|
||||
for register, value in self.outputvars or []:
|
||||
rvalue = parser.parse_expression(register)
|
||||
assignment = ParseResult.AssignmentStmt([value], rvalue, self.lineno)
|
||||
# note: we need the identity assignment here or the output register handling generates buggy code
|
||||
assignment.lineno = self.lineno
|
||||
self.desugared_output_assignments.append(assignment)
|
||||
|
||||
class InlineAsm(_AstNode):
|
||||
def __init__(self, lineno: int, asmlines: List[str]) -> None:
|
||||
@ -502,19 +511,27 @@ class Parser:
|
||||
try:
|
||||
return self.parse_file()
|
||||
except ParseError as x:
|
||||
print()
|
||||
if sys.stderr.isatty():
|
||||
print("\x1b[1m", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
if x.sourcetext:
|
||||
print("\tsource text: '{:s}'".format(x.sourcetext))
|
||||
print("\tsource text: '{:s}'".format(x.sourcetext), file=sys.stderr)
|
||||
if x.sourceref.column:
|
||||
print("\t" + ' '*x.sourceref.column + ' ^')
|
||||
print("\t" + ' '*x.sourceref.column + ' ^', file=sys.stderr)
|
||||
if self.parsing_import:
|
||||
print("Error (in imported file):", str(x))
|
||||
print("Error (in imported file):", str(x), file=sys.stderr)
|
||||
else:
|
||||
print("Error:", str(x))
|
||||
print("Error:", str(x), file=sys.stderr)
|
||||
if sys.stderr.isatty():
|
||||
print("\x1b[0m", file=sys.stderr)
|
||||
raise # XXX temporary solution to get stack trace info in the event of parse errors
|
||||
except Exception as x:
|
||||
print("\nERROR: internal parser error: ", x)
|
||||
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line)
|
||||
if sys.stderr.isatty():
|
||||
print("\x1b[1m", file=sys.stderr)
|
||||
print("\nERROR: internal parser error: ", x, file=sys.stderr)
|
||||
print(" file:", self.sourceref.file, "block:", self.cur_block.name, "line:", self.sourceref.line, file=sys.stderr)
|
||||
if sys.stderr.isatty():
|
||||
print("\x1b[0m", file=sys.stderr)
|
||||
raise # XXX temporary solution to get stack trace info in the event of parse errors
|
||||
|
||||
def parse_file(self) -> ParseResult:
|
||||
@ -524,7 +541,10 @@ class Parser:
|
||||
return self.result
|
||||
|
||||
def print_warning(self, text: str) -> None:
|
||||
print(text)
|
||||
if sys.stdout.isatty():
|
||||
print("\x1b[1m" + text + "\x1b[0m")
|
||||
else:
|
||||
print(text)
|
||||
|
||||
def _parse_comments(self) -> None:
|
||||
while True:
|
||||
@ -604,22 +624,26 @@ class Parser:
|
||||
if isinstance(stmt, ParseResult.CallStmt):
|
||||
self.sourceref.line = stmt.lineno
|
||||
self.sourceref.column = 0
|
||||
stmt.desugar_call_arguments(self)
|
||||
stmt.desugar_call_arguments_and_outputs(self)
|
||||
# create parameter loads for calls, in subroutine blocks
|
||||
for sub in block.symbols.iter_subroutines(True):
|
||||
for stmt in sub.sub_block.statements:
|
||||
if isinstance(stmt, ParseResult.CallStmt):
|
||||
self.sourceref.line = stmt.lineno
|
||||
self.sourceref.column = 0
|
||||
stmt.desugar_call_arguments(self)
|
||||
stmt.desugar_call_arguments_and_outputs(self)
|
||||
block.flatten_statement_list()
|
||||
# desugar immediate string value assignments
|
||||
for index, stmt in enumerate(list(block.statements)):
|
||||
if isinstance(stmt, ParseResult.CallStmt):
|
||||
for stmt in stmt.desugared_call_arguments:
|
||||
self.sourceref.line = stmt.lineno
|
||||
for s in stmt.desugared_call_arguments:
|
||||
self.sourceref.line = s.lineno
|
||||
self.sourceref.column = 0
|
||||
stmt.desugar_immediate_string(self)
|
||||
s.desugar_immediate_string(self)
|
||||
for s in stmt.desugared_output_assignments:
|
||||
self.sourceref.line = s.lineno
|
||||
self.sourceref.column = 0
|
||||
s.desugar_immediate_string(self)
|
||||
if isinstance(stmt, ParseResult.AssignmentStmt):
|
||||
self.sourceref.line = stmt.lineno
|
||||
self.sourceref.column = 0
|
||||
@ -938,7 +962,7 @@ class Parser:
|
||||
all_paramnames = [p[0] for p in parameters if p[0]]
|
||||
if len(all_paramnames) != len(set(all_paramnames)):
|
||||
raise self.PError("duplicates in parameter names")
|
||||
results = {match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)}
|
||||
results = [match.group("name") for match in re.finditer(r"\s*(?P<name>(?:\w+)\??)\s*(?:,|$)", resultlist)]
|
||||
subroutine_block = None
|
||||
if code_decl:
|
||||
address = None
|
||||
@ -1016,17 +1040,31 @@ class Parser:
|
||||
return varname, datatype, length, matrix_dimensions, valuetext
|
||||
|
||||
def parse_statement(self, line: str) -> ParseResult._AstNode:
|
||||
match = re.match(r"(?P<outputs>.*\s*=)\s*(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
|
||||
if match:
|
||||
# subroutine call (not a goto) with output param assignment
|
||||
preserve = not bool(match.group("fcall"))
|
||||
subname = match.group("subname")
|
||||
arguments = match.group("arguments")
|
||||
outputs = match.group("outputs")
|
||||
if outputs.strip() == "=":
|
||||
raise self.PError("missing assignment target variables")
|
||||
outputs = outputs.rstrip("=")
|
||||
if arguments or match.group(4):
|
||||
return self.parse_call_or_goto(subname, arguments, outputs, preserve, False)
|
||||
# apparently it is not a call (no arguments), fall through
|
||||
match = re.match(r"(?P<goto>goto\s+)?(?P<subname>[\S]+?)\s*(?P<fcall>[!]?)\s*(\((?P<arguments>.*)\))?\s*$", line)
|
||||
if match:
|
||||
# subroutine or goto call
|
||||
# subroutine or goto call, without output param assignment
|
||||
is_goto = bool(match.group("goto"))
|
||||
preserve = not bool(match.group("fcall"))
|
||||
subname = match.group("subname")
|
||||
arguments = match.group("arguments")
|
||||
if is_goto:
|
||||
return self.parse_call_or_goto(subname, arguments, preserve, True)
|
||||
return self.parse_call_or_goto(subname, arguments, None, preserve, True)
|
||||
elif arguments or match.group(4):
|
||||
return self.parse_call_or_goto(subname, arguments, preserve, False)
|
||||
return self.parse_call_or_goto(subname, arguments, None, preserve, False)
|
||||
# apparently it is not a call (no arguments), fall through
|
||||
if line == "return" or line.startswith(("return ", "return\t")):
|
||||
return self.parse_return(line)
|
||||
elif line.endswith(("++", "--")):
|
||||
@ -1042,9 +1080,12 @@ class Parser:
|
||||
return self.parse_assignment(line)
|
||||
raise self.PError("invalid statement")
|
||||
|
||||
def parse_call_or_goto(self, targetstr: str, argumentstr: str, preserve_regs=True, is_goto=False) -> ParseResult.CallStmt:
|
||||
def parse_call_or_goto(self, targetstr: str, argumentstr: str, outputstr: str,
|
||||
preserve_regs=True, is_goto=False) -> ParseResult.CallStmt:
|
||||
argumentstr = argumentstr.strip() if argumentstr else ""
|
||||
outputstr = outputstr.strip() if outputstr else ""
|
||||
arguments = None
|
||||
outputvars = None
|
||||
if argumentstr:
|
||||
arguments = parse_arguments(argumentstr, self.sourceref)
|
||||
target = None # type: ParseResult.Value
|
||||
@ -1080,19 +1121,35 @@ class Parser:
|
||||
argname = preg
|
||||
args_with_pnames.append((argname, value))
|
||||
arguments = args_with_pnames
|
||||
self.result.sub_used_by(symbol, self.sourceref)
|
||||
# verify output parameters
|
||||
if symbol.return_registers:
|
||||
if outputstr:
|
||||
outputs = [r.strip() for r in outputstr.split(",")]
|
||||
if len(outputs) != len(symbol.return_registers):
|
||||
raise self.PError("invalid number of output parameters consumed ({:d}, expected {:d})"
|
||||
.format(len(outputs), len(symbol.return_registers)))
|
||||
outputvars = list(zip(symbol.return_registers, (self.parse_expression(out) for out in outputs)))
|
||||
else:
|
||||
self.print_warning("warning: {}: return values discarded".format(self.sourceref))
|
||||
else:
|
||||
if outputstr:
|
||||
raise self.PError("this subroutine doesn't have output parameters")
|
||||
self.result.sub_used_by(symbol, self.sourceref) # sub usage tracking
|
||||
else:
|
||||
if outputstr:
|
||||
raise self.PError("call cannot use output parameter assignment here, a subroutine is required for that")
|
||||
if arguments:
|
||||
raise self.PError("call cannot take any arguments here, use a subroutine for that")
|
||||
if arguments:
|
||||
# verify that all arguments have gotten a name
|
||||
if any(not a[0] for a in arguments):
|
||||
raise self.PError("all call arguments should have a name or be matched on a named parameter")
|
||||
raise self.PError("call cannot take any arguments here, a subroutine is required for that")
|
||||
# verify that all arguments have gotten a name
|
||||
if any(not a[0] for a in arguments or []):
|
||||
raise self.PError("all call arguments should have a name or be matched on a named parameter")
|
||||
if isinstance(target, (type(None), ParseResult.Value)):
|
||||
if is_goto:
|
||||
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, is_goto=True)
|
||||
return ParseResult.CallStmt(self.sourceref.line, target, address=address,
|
||||
arguments=arguments, outputs=outputvars, is_goto=True)
|
||||
else:
|
||||
return ParseResult.CallStmt(self.sourceref.line, target, address=address, arguments=arguments, preserve_regs=preserve_regs)
|
||||
return ParseResult.CallStmt(self.sourceref.line, target, address=address,
|
||||
arguments=arguments, outputs=outputvars, preserve_regs=preserve_regs)
|
||||
else:
|
||||
raise TypeError("target should be a Value", target)
|
||||
|
||||
@ -1216,7 +1273,6 @@ class Parser:
|
||||
|
||||
def parse_expression(self, text: str, is_indirect=False) -> ParseResult.Value:
|
||||
# parse an expression into whatever it is (primitive value, register, memory, register, etc)
|
||||
# @todo only numeric expressions supported for now
|
||||
text = text.strip()
|
||||
if not text:
|
||||
raise self.PError("value expected")
|
||||
|
@ -167,29 +167,26 @@ class ConstantDef(SymbolDefinition):
|
||||
|
||||
class SubroutineDef(SymbolDefinition):
|
||||
def __init__(self, blockname: str, name: str, sourceref: SourceRef,
|
||||
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str],
|
||||
parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
|
||||
address: Optional[int]=None, sub_block: Any=None) -> None:
|
||||
super().__init__(blockname, name, sourceref, False)
|
||||
self.address = address
|
||||
self.sub_block = sub_block # this is a ParseResult.Block
|
||||
self.parameters = parameters
|
||||
self.input_registers = set() # type: Set[str]
|
||||
self.return_registers = set() # type: Set[str]
|
||||
self.clobbered_registers = set() # type: Set[str]
|
||||
self.return_registers = [] # type: List[str] # ordered!
|
||||
for _, param in parameters:
|
||||
if param in REGISTER_BYTES:
|
||||
self.input_registers.add(param)
|
||||
self.clobbered_registers.add(param)
|
||||
elif param in REGISTER_WORDS:
|
||||
self.input_registers.add(param[0])
|
||||
self.input_registers.add(param[1])
|
||||
self.clobbered_registers.add(param[0])
|
||||
self.clobbered_registers.add(param[1])
|
||||
else:
|
||||
raise SymbolError("invalid parameter spec: " + param)
|
||||
for register in returnvalues:
|
||||
if register in REGISTER_SYMBOLS_RETURNVALUES:
|
||||
self.return_registers.add(register)
|
||||
self.clobbered_registers.add(register)
|
||||
self.return_registers.append(register)
|
||||
elif register[-1] == "?":
|
||||
for r in register[:-1]:
|
||||
if r not in REGISTER_SYMBOLS_RETURNVALUES:
|
||||
@ -387,7 +384,7 @@ class SymbolTable:
|
||||
self.eval_dict = None
|
||||
|
||||
def define_sub(self, name: str, sourceref: SourceRef,
|
||||
parameters: Sequence[Tuple[str, str]], returnvalues: Set[str],
|
||||
parameters: Sequence[Tuple[str, str]], returnvalues: Sequence[str],
|
||||
address: Optional[int], sub_block: Any) -> None:
|
||||
self.check_identifier_valid(name, sourceref)
|
||||
self.symbols[name] = SubroutineDef(self.name, name, sourceref, parameters, returnvalues, address, sub_block)
|
||||
|
@ -3,7 +3,6 @@ output prg,sys
|
||||
import "c64lib"
|
||||
|
||||
~ main {
|
||||
|
||||
var .text name = "?"*80
|
||||
|
||||
start
|
||||
|
Loading…
x
Reference in New Issue
Block a user