diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala index 01830b26..92c3b9db 100644 --- a/src/main/scala/millfork/compiler/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -618,7 +618,41 @@ object BuiltIns { def isRhsStack(xs: List[AssemblyLine]): Boolean = xs.exists(_.opcode == TSX) + val canUseIncDec = !decimal && targetBytes.forall(_.forall(l => l.opcode != STA || (l.addrMode match { + case AddrMode.Absolute => true + case AddrMode.AbsoluteX => true + case AddrMode.ZeroPage => true + case AddrMode.ZeroPageX => true + case _ => false + }))) + + def doDec(lines: List[List[AssemblyLine]]):List[AssemblyLine] = lines match { + case Nil => Nil + case x :: Nil => staTo(DEC, x) + case x :: xs => + val label = MlCompiler.nextLabel("de") + staTo(LDA, x) ++ + List(AssemblyLine.relative(BNE, label)) ++ + doDec(xs) ++ + List(AssemblyLine.label(label)) ++ + staTo(DEC, x) + } + val (calculateRhs, addendByteRead0): (List[AssemblyLine], List[List[AssemblyLine]]) = env.eval(addend) match { + case Some(NumericConstant(0, _)) => + return Nil + case Some(NumericConstant(1, _)) if canUseIncDec && !subtract => + val label = MlCompiler.nextLabel("in") + return staTo(INC, targetBytes.head) ++ targetBytes.tail.flatMap(l => AssemblyLine.relative(BNE, label)::staTo(INC, l)) :+ AssemblyLine.label(label) + case Some(NumericConstant(-1, _)) if canUseIncDec && subtract => + val label = MlCompiler.nextLabel("in") + return staTo(INC, targetBytes.head) ++ targetBytes.tail.flatMap(l => AssemblyLine.relative(BNE, label)::staTo(INC, l)) :+ AssemblyLine.label(label) + case Some(NumericConstant(1, _)) if canUseIncDec && subtract => + val label = MlCompiler.nextLabel("de") + return doDec(targetBytes) + case Some(NumericConstant(-1, _)) if canUseIncDec && !subtract => + val label = MlCompiler.nextLabel("de") + return doDec(targetBytes) case Some(constant) => addendSize = targetSize Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i)))) diff --git a/src/test/scala/millfork/test/LongTest.scala b/src/test/scala/millfork/test/LongTest.scala index 129bb24f..195a599f 100644 --- a/src/test/scala/millfork/test/LongTest.scala +++ b/src/test/scala/millfork/test/LongTest.scala @@ -157,4 +157,44 @@ class LongTest extends FunSuite with Matchers { m.readLong(0xc000) should equal(0x44) } } + + test("Long INC/DEC") { + EmuBenchmarkRun( + """ + | long output0 @$c000 + | long output1 @$c004 + | long output2 @$c008 + | long output3 @$c00c + | long output4 @$c010 + | long output5 @$c014 + | long output6 @$c018 + | void main () { + | output0 = 0 + | output1 = $FF + | output2 = $FFFF + | output3 = $FF00 + | output4 = $FF00 + | output5 = $10000 + | output6 = 0 + | barrier() + | output0 += 1 + | output1 += 1 + | output2 += 1 + | output3 += 1 + | output4 -= 1 + | output5 -= 1 + | output6 -= 1 + | } + | void barrier() { + | } + """.stripMargin) { m => + m.readLong(0xc000) should equal(1) + m.readLong(0xc004) should equal(0x100) + m.readLong(0xc008) should equal(0x10000) + m.readLong(0xc00c) should equal(0xff01) + m.readLong(0xc010) should equal(0xfeff) + m.readLong(0xc014) should equal(0xffff) + m.readLong(0xc018) should equal(0xffffffff) + } + } }