avoid needless saving of A register

This commit is contained in:
Irmen de Jong 2017-12-25 19:22:22 +01:00
parent afaf8e9beb
commit a0a561cfb6
2 changed files with 31 additions and 22 deletions

View File

@ -6,7 +6,6 @@ Written by Irmen de Jong (irmen@razorvine.net)
License: GNU GPL 3.0, see LICENSE License: GNU GPL 3.0, see LICENSE
""" """
import sys
import io import io
import math import math
import datetime import datetime
@ -435,9 +434,18 @@ class CodeGenerator:
raise NotImplementedError("decr by > 1") # XXX raise NotImplementedError("decr by > 1") # XXX
def generate_call(self, stmt: ParseResult.CallStmt) -> None: def generate_call(self, stmt: ParseResult.CallStmt) -> None:
def generate_param_assignments(): def generate_param_assignments() -> None:
for assign_stmt in stmt.desugared_call_arguments: for assign_stmt in stmt.desugared_call_arguments:
self.generate_assignment(assign_stmt) self.generate_assignment(assign_stmt)
def params_load_a() -> bool:
for assign_stmt in stmt.desugared_call_arguments:
for lv in assign_stmt.leftvalues:
if isinstance(lv, ParseResult.RegisterValue):
if lv.register == 'A':
return True
return False
if stmt.target.name: if stmt.target.name:
symblock, targetdef = self.cur_block.lookup(stmt.target.name) symblock, targetdef = self.cur_block.lookup(stmt.target.name)
else: else:
@ -456,7 +464,7 @@ class CodeGenerator:
if targetdef.clobbered_registers: if targetdef.clobbered_registers:
if stmt.preserve_regs: if stmt.preserve_regs:
clobbered = targetdef.clobbered_registers clobbered = targetdef.clobbered_registers
with self.preserving_registers(clobbered): with self.preserving_registers(clobbered, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
self.p("\t\tjsr " + targetstr) self.p("\t\tjsr " + targetstr)
return return
@ -484,7 +492,7 @@ class CodeGenerator:
self.p("\t\tjmp ({:s})".format(targetstr)) self.p("\t\tjmp ({:s})".format(targetstr))
else: else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
with self.preserving_registers(preserve_regs): with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
if targetstr in REGISTER_WORDS: if targetstr in REGISTER_WORDS:
if stmt.preserve_regs: if stmt.preserve_regs:
@ -526,7 +534,7 @@ class CodeGenerator:
self.p("\t\tjmp " + targetstr) self.p("\t\tjmp " + targetstr)
else: else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set() preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
with self.preserving_registers(preserve_regs): with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments() generate_param_assignments()
self.p("\t\tjsr " + targetstr) self.p("\t\tjsr " + targetstr)
@ -635,7 +643,7 @@ class CodeGenerator:
elif lv.datatype == DataType.WORD: elif lv.datatype == DataType.WORD:
if len(r_register) == 1: if len(r_register) == 1:
self.p("\t\tst{:s} {}".format(r_register.lower(), lv_string)) # lsb self.p("\t\tst{:s} {}".format(r_register.lower(), lv_string)) # lsb
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #0") self.p("\t\tlda #0")
self.p("\t\tsta {:s}+1".format(lv_string)) # msb self.p("\t\tsta {:s}+1".format(lv_string)) # msb
else: else:
@ -657,7 +665,7 @@ class CodeGenerator:
self.p("\t\tpha\n\t\ttxa\n\t\ttay\n\t\tpla") # X->Y (so we have AY now) self.p("\t\tpha\n\t\ttxa\n\t\ttay\n\t\tpla") # X->Y (so we have AY now)
do_rom_calls() do_rom_calls()
else: # XY else: # XY
with self.preserving_registers({'A', 'X', 'Y'}): with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\ttxa") # X->A (so we have AY now) self.p("\t\ttxa") # X->A (so we have AY now)
do_rom_calls() do_rom_calls()
elif r_register in "AXY": elif r_register in "AXY":
@ -673,7 +681,7 @@ class CodeGenerator:
self.p("\t\ttay") self.p("\t\ttay")
do_rom_calls() do_rom_calls()
elif r_register == "X": elif r_register == "X":
with self.preserving_registers({'A', 'X', 'Y'}): with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\ttxa") self.p("\t\ttxa")
self.p("\t\ttay") self.p("\t\ttay")
do_rom_calls() do_rom_calls()
@ -750,23 +758,24 @@ class CodeGenerator:
raise CodeError("invalid register " + lv.register) raise CodeError("invalid register " + lv.register)
@contextlib.contextmanager @contextlib.contextmanager
def preserving_registers(self, registers: Set[str]): def preserving_registers(self, registers: Set[str], loads_a_within: bool=False):
# @todo option to avoid the sta $03/lda$03 when a is loaded anyway # this clobbers a ZP scratch register and is therefore NOT safe to use in interrupts
# this clobbers a ZP scratch register and is therefore safe to use in interrupts
# see http://6502.org/tutorials/register_preservation.html # see http://6502.org/tutorials/register_preservation.html
if registers == {'A'}: if registers == {'A'}:
self.p("\t\tpha") self.p("\t\tpha")
yield yield
self.p("\t\tpla") self.p("\t\tpla")
elif registers: elif registers:
self.p("\t\tsta ${:02x}".format(Zeropage.SCRATCH_B2)) if not loads_a_within:
self.p("\t\tsta ${:02x}".format(Zeropage.SCRATCH_B2))
if 'A' in registers: if 'A' in registers:
self.p("\t\tpha") self.p("\t\tpha")
if 'X' in registers: if 'X' in registers:
self.p("\t\ttxa\n\t\tpha") self.p("\t\ttxa\n\t\tpha")
if 'Y' in registers: if 'Y' in registers:
self.p("\t\ttya\n\t\tpha") self.p("\t\ttya\n\t\tpha")
self.p("\t\tlda ${:02x}".format(Zeropage.SCRATCH_B2)) if not loads_a_within:
self.p("\t\tlda ${:02x}".format(Zeropage.SCRATCH_B2))
yield yield
if 'Y' in registers: if 'Y' in registers:
self.p("\t\tpla\n\t\ttay") self.p("\t\tpla\n\t\ttay")
@ -791,13 +800,13 @@ class CodeGenerator:
if lvdatatype == DataType.BYTE: if lvdatatype == DataType.BYTE:
if rvalue.value is not None and not lv.assignable_from(rvalue) or rvalue.datatype != DataType.BYTE: if rvalue.value is not None and not lv.assignable_from(rvalue) or rvalue.datatype != DataType.BYTE:
raise OverflowError("value doesn't fit in a byte") raise OverflowError("value doesn't fit in a byte")
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #" + r_str) self.p("\t\tlda #" + r_str)
self.p("\t\tsta " + assign_target) self.p("\t\tsta " + assign_target)
elif lvdatatype == DataType.WORD: elif lvdatatype == DataType.WORD:
if rvalue.value is not None and not lv.assignable_from(rvalue): if rvalue.value is not None and not lv.assignable_from(rvalue):
raise OverflowError("value doesn't fit in a word") raise OverflowError("value doesn't fit in a word")
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #<" + r_str) self.p("\t\tlda #<" + r_str)
self.p("\t\tsta " + assign_target) self.p("\t\tsta " + assign_target)
self.p("\t\tlda #>" + r_str) self.p("\t\tlda #>" + r_str)
@ -831,18 +840,18 @@ class CodeGenerator:
if lv.datatype == DataType.BYTE: if lv.datatype == DataType.BYTE:
if rvalue.datatype != DataType.BYTE: if rvalue.datatype != DataType.BYTE:
raise CodeError("can only assign a byte to a byte", str(rvalue)) raise CodeError("can only assign a byte to a byte", str(rvalue))
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str) self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str) self.p("\t\tsta " + l_str)
elif lv.datatype == DataType.WORD: elif lv.datatype == DataType.WORD:
if rvalue.datatype == DataType.BYTE: if rvalue.datatype == DataType.BYTE:
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str) self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str) self.p("\t\tsta " + l_str)
self.p("\t\tlda #0") self.p("\t\tlda #0")
self.p("\t\tsta {:s}+1".format(l_str)) self.p("\t\tsta {:s}+1".format(l_str))
elif rvalue.datatype == DataType.WORD: elif rvalue.datatype == DataType.WORD:
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda {:s}".format(r_str)) self.p("\t\tlda {:s}".format(r_str))
self.p("\t\tsta {:s}".format(l_str)) self.p("\t\tsta {:s}".format(l_str))
self.p("\t\tlda {:s}+1".format(r_str)) self.p("\t\tlda {:s}+1".format(r_str))
@ -851,7 +860,7 @@ class CodeGenerator:
raise CodeError("can only assign a byte or word to a word", str(rvalue)) raise CodeError("can only assign a byte or word to a word", str(rvalue))
elif lv.datatype == DataType.FLOAT: elif lv.datatype == DataType.FLOAT:
if rvalue.datatype == DataType.FLOAT: if rvalue.datatype == DataType.FLOAT:
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str) self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str) self.p("\t\tsta " + l_str)
self.p("\t\tlda {:s}+1".format(r_str)) self.p("\t\tlda {:s}+1".format(r_str))
@ -870,7 +879,7 @@ class CodeGenerator:
self.p("\t\tldy #>" + l_str) self.p("\t\tldy #>" + l_str)
self.p("\t\tjsr c64.FTOMEMXY") # fac1 -> memory XY self.p("\t\tjsr c64.FTOMEMXY") # fac1 -> memory XY
elif rvalue.datatype == DataType.WORD: elif rvalue.datatype == DataType.WORD:
with self.preserving_registers({'A', 'X', 'Y'}): with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\tlda " + r_str) self.p("\t\tlda " + r_str)
self.p("\t\tldy {:s}+1".format(r_str)) self.p("\t\tldy {:s}+1".format(r_str))
self.p("\t\tjsr c64util.GIVUAYF") # uword AY -> fac1 self.p("\t\tjsr c64util.GIVUAYF") # uword AY -> fac1
@ -884,7 +893,7 @@ class CodeGenerator:
def generate_assign_char_to_memory(self, lv: ParseResult.MemMappedValue, char_str: str) -> None: def generate_assign_char_to_memory(self, lv: ParseResult.MemMappedValue, char_str: str) -> None:
# Memory = Character # Memory = Character
with self.preserving_registers({'A'}): with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #" + char_str) self.p("\t\tlda #" + char_str)
if not lv.name: if not lv.name:
self.p("\t\tsta " + Parser.to_hex(lv.address)) self.p("\t\tsta " + Parser.to_hex(lv.address))

View File

@ -1176,7 +1176,7 @@ class Parser:
except ParseError: except ParseError:
pass pass
else: else:
if symbol: if isinstance(symbol, SubroutineDef):
self.result.sub_used_by(symbol, self.sourceref) self.result.sub_used_by(symbol, self.sourceref)
asmlines.append(line) asmlines.append(line)