From 2a5933e115c7a4af2be982d1c58fc09a3b5852fb Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Wed, 13 Jan 2021 14:18:28 +0100 Subject: [PATCH] 6502: Fix sbyte to word promotion in some contexts --- .../millfork/assembly/mos/AssemblyLine.scala | 28 +++++++--- .../compiler/mos/MosExpressionCompiler.scala | 9 +++- .../millfork/test/SignExtensionSuite.scala | 15 ++++++ .../scala/millfork/test/WordMathSuite.scala | 52 +++++++++++++++++++ 4 files changed, 95 insertions(+), 9 deletions(-) diff --git a/src/main/scala/millfork/assembly/mos/AssemblyLine.scala b/src/main/scala/millfork/assembly/mos/AssemblyLine.scala index 5b3c3ece..c3933e2b 100644 --- a/src/main/scala/millfork/assembly/mos/AssemblyLine.scala +++ b/src/main/scala/millfork/assembly/mos/AssemblyLine.scala @@ -422,7 +422,7 @@ object AssemblyLine { private val opcodesForZeroedVariableOperation = Set(ADC, EOR, ORA, AND, SBC, CMP, CPX, CPY) private val opcodesForZeroedOrSignExtendedVariableOperation = Set(LDA, LDX, LDY, LDZ) - def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] = + def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0, preserveA: Boolean = false): List[AssemblyLine] = if (offset >= variable.typ.size) { if (opcodesForNopVariableOperation(opcode)) { Nil @@ -432,16 +432,28 @@ object AssemblyLine { } else if (opcodesForZeroedOrSignExtendedVariableOperation(opcode)) { if (variable.typ.isSigned) { val label = ctx.nextLabel("sx") - AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) ++ List( + val loadHiByteToA = AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) + val signExtend = List( AssemblyLine.immediate(ORA, 0x7f), AssemblyLine.relative(BMI, label), AssemblyLine.immediate(LDA, 0), - AssemblyLine.label(label)) ++ (opcode match { - case LDA => Nil - case LDX | LAX => List(AssemblyLine.implied(TAX)) - case LDY => List(AssemblyLine.implied(TAY)) - case LDZ => List(AssemblyLine.implied(TAZ)) - }) + AssemblyLine.label(label)) + if (preserveA) { + opcode match { + case LDA => loadHiByteToA ++ signExtend + case LAX => loadHiByteToA ++ signExtend ++ List(AssemblyLine.implied(TAX)) + case LDX => loadHiByteToA ++ List(AssemblyLine.implied(PHA)) ++ signExtend ++ List(AssemblyLine.implied(TAX), AssemblyLine.implied(PLA)) + case LDY => loadHiByteToA ++ List(AssemblyLine.implied(PHA)) ++ signExtend ++ List(AssemblyLine.implied(TAY), AssemblyLine.implied(PLA)) + case LDZ => loadHiByteToA ++ List(AssemblyLine.implied(PHA)) ++ signExtend ++ List(AssemblyLine.implied(TAZ), AssemblyLine.implied(PLA)) + } + } else { + opcode match { + case LDA => loadHiByteToA ++ signExtend + case LDX | LAX => loadHiByteToA ++ signExtend ++ List(AssemblyLine.implied(TAX)) + case LDY => loadHiByteToA ++ signExtend ++ List(AssemblyLine.implied(TAY)) + case LDZ => loadHiByteToA ++ signExtend ++ List(AssemblyLine.implied(TAZ)) + } + } } else { List(AssemblyLine.immediate(opcode, 0)) } diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 16f1ab36..b53c58d1 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -733,7 +733,14 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { AssemblyLine.implied(PLA)) } else AssemblyLine.variable(ctx, LDA, source) :+ AssemblyLine.immediate(LDX, 0) case 2 => - AssemblyLine.variable(ctx, LDA, source) ++ AssemblyLine.variable(ctx, LDX, source, 1) + val lo = AssemblyLine.variable(ctx, LDA, source) + if (lo.exists(_.concernsX)) { + lo ++ AssemblyLine.variable(ctx, LDX, source, 1, preserveA = true) + } else { + val hi = AssemblyLine.variable(ctx, LDX, source, 1) + if (hi.length == 1 && hi.head.opcode == LDX) lo ++ hi + else hi ++ lo + } } case RegisterVariable(MosRegister.AY, _) => exprType.size match { diff --git a/src/test/scala/millfork/test/SignExtensionSuite.scala b/src/test/scala/millfork/test/SignExtensionSuite.scala index 2c4bd6af..6f29baf9 100644 --- a/src/test/scala/millfork/test/SignExtensionSuite.scala +++ b/src/test/scala/millfork/test/SignExtensionSuite.scala @@ -32,6 +32,21 @@ class SignExtensionSuite extends FunSuite with Matchers { | } """.stripMargin){m => m.readWord(0xc000) should equal(0xfffe)} } + + test("Sbyte to Word 3") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(""" + | import zp_reg + | word output @$c000 + | void main () { + | sbyte x,y + | x = b(1) + | y = b(31) + | output = word(x) * word(y) + | } + | noinline sbyte b(sbyte x) = x + """.stripMargin){m => m.readWord(0xc000) should equal(31)} + } + test("Sbyte to Long") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp, Cpu.Intel8086, Cpu.Motorola6809)(""" | long output @$c000 diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index 50715be2..eebd4bd2 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -762,4 +762,56 @@ class WordMathSuite extends FunSuite with Matchers with AppendedClues { } } } + + test("Sign extension in multiplication") { + for { + x <- Seq(0, -10, 10, 120, -120) + y <- Seq(0, -10, 10, 120, -120) + angle <- Seq(0, 156, 100, 67) + } { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + s""" + | import zp_reg + | array(sbyte) sinTable @$$c100= for i,0,to,256+64-1 [sin(i,127)] + | sbyte outputX @$$c000 + | sbyte outputY @$$c001 + | noinline sbyte rotatePointX(sbyte x,sbyte y,byte angle) { + | sbyte rx + | sbyte s,c + | s = sinTable[angle] + | angle = angle + 64 + | c = sinTable[angle] + | rx = lo(((word(x)*word(c))>>7) - ((word(y)*word(s))>>7)) + | + | return rx; + |} + | + |noinline sbyte rotatePointY(sbyte x,sbyte y,byte angle) { + | sbyte ry + | sbyte s,c + | s = sinTable[angle] + | angle = angle + 64 + | c = sinTable[angle] + | ry = lo(((word(x)*word(s))>>7) + ((word(y)*word(c))>>7)) + | return ry; + |} + | void main () { + | outputX = rotatePointX($x, $y, $angle) + | outputY = rotatePointY($x, $y, $angle) + | } + """. + stripMargin){m => + for (a <- 0 until (256+64)) { + val expected = (127 * math.sin(a * math.Pi / 128)).round.toInt + m.readByte(0xc100 + a).toByte should equal(expected.toByte) withClue s"= sin($a)" + } + val s = (127 * math.sin(angle * math.Pi / 128)).round.toInt + val c = (127 * math.sin((angle + 64) * math.Pi / 128)).round.toInt + val rx = (x * c >> 7) - (y * s >> 7) + val ry = (x * s >> 7) + (y * c >> 7) + m.readByte(0xc000).toByte should equal(rx.toByte) withClue s"= x of ($x,$y) @ $angle" + m.readByte(0xc001).toByte should equal(ry.toByte) withClue s"= y of ($x,$y) @ $angle" + } + } + } }