avoid needless saving of A register

This commit is contained in:
Irmen de Jong 2017-12-25 19:22:22 +01:00
parent afaf8e9beb
commit a0a561cfb6
2 changed files with 31 additions and 22 deletions

View File

@ -6,7 +6,6 @@ Written by Irmen de Jong (irmen@razorvine.net)
License: GNU GPL 3.0, see LICENSE
"""
import sys
import io
import math
import datetime
@ -435,9 +434,18 @@ class CodeGenerator:
raise NotImplementedError("decr by > 1") # XXX
def generate_call(self, stmt: ParseResult.CallStmt) -> None:
def generate_param_assignments():
def generate_param_assignments() -> None:
for assign_stmt in stmt.desugared_call_arguments:
self.generate_assignment(assign_stmt)
def params_load_a() -> bool:
for assign_stmt in stmt.desugared_call_arguments:
for lv in assign_stmt.leftvalues:
if isinstance(lv, ParseResult.RegisterValue):
if lv.register == 'A':
return True
return False
if stmt.target.name:
symblock, targetdef = self.cur_block.lookup(stmt.target.name)
else:
@ -456,7 +464,7 @@ class CodeGenerator:
if targetdef.clobbered_registers:
if stmt.preserve_regs:
clobbered = targetdef.clobbered_registers
with self.preserving_registers(clobbered):
with self.preserving_registers(clobbered, loads_a_within=params_load_a()):
generate_param_assignments()
self.p("\t\tjsr " + targetstr)
return
@ -484,7 +492,7 @@ class CodeGenerator:
self.p("\t\tjmp ({:s})".format(targetstr))
else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
with self.preserving_registers(preserve_regs):
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments()
if targetstr in REGISTER_WORDS:
if stmt.preserve_regs:
@ -526,7 +534,7 @@ class CodeGenerator:
self.p("\t\tjmp " + targetstr)
else:
preserve_regs = {'A', 'X', 'Y'} if stmt.preserve_regs else set()
with self.preserving_registers(preserve_regs):
with self.preserving_registers(preserve_regs, loads_a_within=params_load_a()):
generate_param_assignments()
self.p("\t\tjsr " + targetstr)
@ -635,7 +643,7 @@ class CodeGenerator:
elif lv.datatype == DataType.WORD:
if len(r_register) == 1:
self.p("\t\tst{:s} {}".format(r_register.lower(), lv_string)) # lsb
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #0")
self.p("\t\tsta {:s}+1".format(lv_string)) # msb
else:
@ -657,7 +665,7 @@ class CodeGenerator:
self.p("\t\tpha\n\t\ttxa\n\t\ttay\n\t\tpla") # X->Y (so we have AY now)
do_rom_calls()
else: # XY
with self.preserving_registers({'A', 'X', 'Y'}):
with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\ttxa") # X->A (so we have AY now)
do_rom_calls()
elif r_register in "AXY":
@ -673,7 +681,7 @@ class CodeGenerator:
self.p("\t\ttay")
do_rom_calls()
elif r_register == "X":
with self.preserving_registers({'A', 'X', 'Y'}):
with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\ttxa")
self.p("\t\ttay")
do_rom_calls()
@ -750,15 +758,15 @@ class CodeGenerator:
raise CodeError("invalid register " + lv.register)
@contextlib.contextmanager
def preserving_registers(self, registers: Set[str]):
# @todo option to avoid the sta $03/lda$03 when a is loaded anyway
# this clobbers a ZP scratch register and is therefore safe to use in interrupts
def preserving_registers(self, registers: Set[str], loads_a_within: bool=False):
# this clobbers a ZP scratch register and is therefore NOT safe to use in interrupts
# see http://6502.org/tutorials/register_preservation.html
if registers == {'A'}:
self.p("\t\tpha")
yield
self.p("\t\tpla")
elif registers:
if not loads_a_within:
self.p("\t\tsta ${:02x}".format(Zeropage.SCRATCH_B2))
if 'A' in registers:
self.p("\t\tpha")
@ -766,6 +774,7 @@ class CodeGenerator:
self.p("\t\ttxa\n\t\tpha")
if 'Y' in registers:
self.p("\t\ttya\n\t\tpha")
if not loads_a_within:
self.p("\t\tlda ${:02x}".format(Zeropage.SCRATCH_B2))
yield
if 'Y' in registers:
@ -791,13 +800,13 @@ class CodeGenerator:
if lvdatatype == DataType.BYTE:
if rvalue.value is not None and not lv.assignable_from(rvalue) or rvalue.datatype != DataType.BYTE:
raise OverflowError("value doesn't fit in a byte")
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #" + r_str)
self.p("\t\tsta " + assign_target)
elif lvdatatype == DataType.WORD:
if rvalue.value is not None and not lv.assignable_from(rvalue):
raise OverflowError("value doesn't fit in a word")
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #<" + r_str)
self.p("\t\tsta " + assign_target)
self.p("\t\tlda #>" + r_str)
@ -831,18 +840,18 @@ class CodeGenerator:
if lv.datatype == DataType.BYTE:
if rvalue.datatype != DataType.BYTE:
raise CodeError("can only assign a byte to a byte", str(rvalue))
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str)
elif lv.datatype == DataType.WORD:
if rvalue.datatype == DataType.BYTE:
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str)
self.p("\t\tlda #0")
self.p("\t\tsta {:s}+1".format(l_str))
elif rvalue.datatype == DataType.WORD:
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda {:s}".format(r_str))
self.p("\t\tsta {:s}".format(l_str))
self.p("\t\tlda {:s}+1".format(r_str))
@ -851,7 +860,7 @@ class CodeGenerator:
raise CodeError("can only assign a byte or word to a word", str(rvalue))
elif lv.datatype == DataType.FLOAT:
if rvalue.datatype == DataType.FLOAT:
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda " + r_str)
self.p("\t\tsta " + l_str)
self.p("\t\tlda {:s}+1".format(r_str))
@ -870,7 +879,7 @@ class CodeGenerator:
self.p("\t\tldy #>" + l_str)
self.p("\t\tjsr c64.FTOMEMXY") # fac1 -> memory XY
elif rvalue.datatype == DataType.WORD:
with self.preserving_registers({'A', 'X', 'Y'}):
with self.preserving_registers({'A', 'X', 'Y'}, loads_a_within=True):
self.p("\t\tlda " + r_str)
self.p("\t\tldy {:s}+1".format(r_str))
self.p("\t\tjsr c64util.GIVUAYF") # uword AY -> fac1
@ -884,7 +893,7 @@ class CodeGenerator:
def generate_assign_char_to_memory(self, lv: ParseResult.MemMappedValue, char_str: str) -> None:
# Memory = Character
with self.preserving_registers({'A'}):
with self.preserving_registers({'A'}, loads_a_within=True):
self.p("\t\tlda #" + char_str)
if not lv.name:
self.p("\t\tsta " + Parser.to_hex(lv.address))

View File

@ -1176,7 +1176,7 @@ class Parser:
except ParseError:
pass
else:
if symbol:
if isinstance(symbol, SubroutineDef):
self.result.sub_used_by(symbol, self.sourceref)
asmlines.append(line)