diff --git a/src/main/scala/millfork/compiler/z80/Z80Shifting.scala b/src/main/scala/millfork/compiler/z80/Z80Shifting.scala index 3cfd99dd..ff72a33a 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Shifting.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Shifting.scala @@ -5,7 +5,7 @@ import millfork.assembly.z80.{NoRegisters, ZLine, ZOpcode} import millfork.compiler.CompilationContext import millfork.env.NumericConstant import millfork.error.ConsoleLogger -import millfork.node.{Expression, LhsExpression, ZRegister} +import millfork.node._ import scala.collection.GenTraversableOnce @@ -14,6 +14,10 @@ import scala.collection.GenTraversableOnce */ object Z80Shifting { + private def calculateIterationCountPlus1(ctx: CompilationContext, rhs: Expression) = { + Z80ExpressionCompiler.compile8BitTo(ctx, SumExpression(List(false -> rhs, false -> LiteralExpression(1, 1)), false), ZRegister.B) + } + private def fixAfterShiftIfNeeded(extendedOps: Boolean, left: Boolean, i: Long): List[ZLine] = if (extendedOps || left) { Nil @@ -41,11 +45,12 @@ object Z80Shifting { l ++ List.tabulate(i.toInt)(_ => op) ++ fixAfterShiftIfNeeded(extendedOps, left, i) } case _ => - val calcCount = Z80ExpressionCompiler.compile8BitTo(ctx, rhs, ZRegister.B) + val calcCount = calculateIterationCountPlus1(ctx, rhs) val l = Z80ExpressionCompiler.stashBCIfChanged(ctx, Z80ExpressionCompiler.compileToA(ctx, lhs)) val loopBody = op :: fixAfterShiftIfNeeded(extendedOps, left, 1) - val label = ctx.nextLabel("sh") - calcCount ++ l ++ List(ZLine.label(label)) ++ loopBody ++ ZLine.djnz(ctx, label) + val labelL = ctx.nextLabel("sh") + val labelS = ctx.nextLabel("sh") + calcCount ++ l ++ List(ZLine.jumpR(ctx, labelS), ZLine.label(labelL)) ++ loopBody ++ List(ZLine.label(labelS)) ++ ZLine.djnz(ctx, labelL) } } @@ -117,11 +122,12 @@ object Z80Shifting { } } case _ => - val calcCount = Z80ExpressionCompiler.compile8BitTo(ctx, rhs, ZRegister.B) + val calcCount = calculateIterationCountPlus1(ctx, rhs) val l = Z80ExpressionCompiler.stashBCIfChanged(ctx, Z80ExpressionCompiler.compileToA(ctx, lhs)) val loopBody = ZLine.register(op, ZRegister.A) :: fixAfterShiftIfNeeded(extendedOps, left, 1) - val label = ctx.nextLabel("sh") - calcCount ++ l ++ List(ZLine.label(label)) ++ loopBody ++ ZLine.djnz(ctx, label) ++ Z80ExpressionCompiler.storeA(ctx, lhs, signedSource = false) + val labelL = ctx.nextLabel("sh") + val labelS = ctx.nextLabel("sh") + calcCount ++ l ++ List(ZLine.jumpR(ctx, labelS), ZLine.label(labelL)) ++ loopBody ++ List(ZLine.label(labelS)) ++ ZLine.djnz(ctx, labelL) ++ Z80ExpressionCompiler.storeA(ctx, lhs, signedSource = false) } } @@ -168,7 +174,7 @@ object Z80Shifting { } } case _ => - val calcCount = Z80ExpressionCompiler.compile8BitTo(ctx, rhs, ZRegister.B) + val calcCount = calculateIterationCountPlus1(ctx, rhs) val loopBody = if (extendedOps) { if (left) { @@ -200,8 +206,9 @@ object Z80Shifting { ZLine.ld8(ZRegister.L, ZRegister.A)) } } - val label = ctx.nextLabel("sh") - calcCount ++ l ++ List(ZLine.label(label)) ++ loopBody ++ ZLine.djnz(ctx, label) + val labelS = ctx.nextLabel("sh") + val labelL = ctx.nextLabel("sh") + calcCount ++ l ++ List(ZLine.jumpR(ctx, labelS), ZLine.label(labelL)) ++ loopBody ++ List(ZLine.label(labelS)) ++ ZLine.djnz(ctx, labelL) } } @@ -270,9 +277,10 @@ object Z80Shifting { case Some(NumericConstant(n, _)) => List.fill(n.toInt)(shiftOne).flatten case _ => - val label = ctx.nextLabel("sh") - val calcCount = Z80ExpressionCompiler.compile8BitTo(ctx, rhs, ZRegister.B) - calcCount ++ List(ZLine.label(label)) ++ shiftOne ++ ZLine.djnz(ctx, label) + val calcCount = calculateIterationCountPlus1(ctx, rhs) + val labelL = ctx.nextLabel("sh") + val labelS = ctx.nextLabel("sh") + calcCount ++ List(ZLine.jumpR(ctx, labelS), ZLine.label(labelL)) ++ shiftOne ++ List(ZLine.label(labelS)) ++ ZLine.djnz(ctx, labelL) } } diff --git a/src/test/scala/millfork/test/ShiftSuite.scala b/src/test/scala/millfork/test/ShiftSuite.scala index b36bf9c5..e9cb670d 100644 --- a/src/test/scala/millfork/test/ShiftSuite.scala +++ b/src/test/scala/millfork/test/ShiftSuite.scala @@ -102,4 +102,25 @@ class ShiftSuite extends FunSuite with Matchers { m.readByte(0xc005) should equal(8) } } + + test("Zero shifting") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(""" + | byte output0 @$c000 + | byte output1 @$c001 + | noinline byte sl(byte input, byte amount) { + | return input << amount + | } + | noinline byte sr(byte input, byte amount) { + | return input >> amount + | } + | void main () { + | output0 = sl(42, 0) + | output1 = sr(42, 0) + | } + | noinline byte b(byte x) { return x } + """.stripMargin){m => + m.readByte(0xc000) should equal(42) + m.readByte(0xc001) should equal(42) + } + } }