From 0bb662183c3429251cb609cf73354288845d4f64 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Wed, 26 Dec 2018 02:05:41 +0100 Subject: [PATCH] Faster comparisons to 0 and $FFFF --- .../millfork/compiler/mos/BuiltIns.scala | 35 ++++++++++++---- .../compiler/z80/Z80Comparisons.scala | 30 ++++++++++++-- .../scala/millfork/test/ComparisonSuite.scala | 40 +++++++++++++++++++ 3 files changed, 93 insertions(+), 12 deletions(-) diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 1a6122a6..180ce211 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -631,15 +631,34 @@ object BuiltIns { } val lType = MosExpressionCompiler.getExpressionType(ctx, lhs) val rType = MosExpressionCompiler.getExpressionType(ctx, rhs) - val compactEqualityComparison = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) { - None - } else if (lType.size == 1 && !lType.isSigned) { - Some(cmpTo(LDA, ll) ++ cmpTo(EOR, rl) ++ cmpTo(ORA, rh)) - } else if (rType.size == 1 && !rType.isSigned) { - Some(cmpTo(LDA, rl) ++ cmpTo(EOR, ll) ++ cmpTo(ORA, lh)) - } else { - None + def isConstant(h: List[AssemblyLine], l: List[AssemblyLine], value: Int): Boolean = { + (h,l) match { + case ( + List(AssemblyLine0(CMP, Immediate, NumericConstant(vh, _))), + List(AssemblyLine0(CMP, Immediate, NumericConstant(vl, _))) + ) if vh.&(0xff).<<(8) + vl.&(0xff) == value => true + case _ => false + } } + + val compactEqualityComparison = + if (isConstant(rh, rl, 0)) { + Some(cmpTo(LDA, ll) ++ cmpTo(ORA, lh)) + } else if (isConstant(lh, ll, 0)) { + Some(cmpTo(LDA, rl) ++ cmpTo(ORA, rh)) + } else if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) { + None + } else if (isConstant(rh, rl, 0xffff)) { + Some(cmpTo(LDA, ll) ++ cmpTo(AND, lh) ++ List(AssemblyLine.immediate(CMP, 0xff))) + } else if (isConstant(lh, ll, 0xffff)) { + Some(cmpTo(LDA, rl) ++ cmpTo(AND, rh) ++ List(AssemblyLine.immediate(CMP, 0xff))) + } else if (lType.size == 1 && !lType.isSigned) { + Some(cmpTo(LDA, ll) ++ cmpTo(EOR, rl) ++ cmpTo(ORA, rh)) + } else if (rType.size == 1 && !rType.isSigned) { + Some(cmpTo(LDA, rl) ++ cmpTo(EOR, ll) ++ cmpTo(ORA, lh)) + } else { + None + } effectiveComparisonType match { case ComparisonType.Equal => compactEqualityComparison match { diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala index 6d355479..99b5a775 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -123,8 +123,34 @@ object Z80Comparisons { return compile16BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) case _ => () } + import ZRegister._ + import ZOpcode._ val calculateLeft = Z80ExpressionCompiler.compileToHL(ctx, l) val calculateRight = Z80ExpressionCompiler.compileToHL(ctx, r) + val fastEqualityComparison: Option[List[ZLine]] = (calculateLeft, calculateRight) match { + case (List(ZLine0(LD_16, TwoRegisters(HL, IMM_16), NumericConstant(0, _))), _) => + Some(calculateRight ++ List(ZLine.ld8(A, H), ZLine.register(OR, L))) + case (List(ZLine0(LD_16, TwoRegisters(HL, IMM_16), NumericConstant(0xffff, _))), _) => + Some(calculateRight ++ List(ZLine.ld8(A, H), ZLine.register(AND, L), ZLine.imm8(CP, 0xff))) + case (_, List(ZLine0(LD_16, TwoRegisters(HL, IMM_16), NumericConstant(0, _)))) => + Some(calculateLeft ++ List(ZLine.ld8(A, H), ZLine.register(OR, L))) + case (_, List(ZLine0(LD_16, TwoRegisters(HL, IMM_16), NumericConstant(0xffff, _)))) => + Some(calculateLeft ++ List(ZLine.ld8(A, H), ZLine.register(AND, L), ZLine.imm8(CP, 0xff))) + case _ => + None + } + fastEqualityComparison match { + case Some(code) => + (compType, branches) match { + case (ComparisonType.Equal, BranchIfTrue(lbl)) => return code :+ ZLine.jumpR(ctx, lbl, IfFlagSet(ZFlag.Z)) + case (ComparisonType.Equal, BranchIfFalse(lbl)) => return code :+ ZLine.jumpR(ctx, lbl, IfFlagClear(ZFlag.Z)) + case (ComparisonType.NotEqual, BranchIfTrue(lbl)) => return code :+ ZLine.jumpR(ctx, lbl, IfFlagClear(ZFlag.Z)) + case (ComparisonType.NotEqual, BranchIfFalse(lbl)) => return code :+ ZLine.jumpR(ctx, lbl, IfFlagSet(ZFlag.Z)) + case _ => + } + case _ => + } + val (calculated, useBC) = if (calculateLeft.exists(Z80ExpressionCompiler.changesBC)) { if (calculateLeft.exists(Z80ExpressionCompiler.changesDE)) { calculateRight ++ List(ZLine.register(ZOpcode.PUSH, ZRegister.HL)) ++ Z80ExpressionCompiler.fixTsx(ctx, calculateLeft) ++ List(ZLine.register(ZOpcode.POP, ZRegister.BC)) -> false @@ -152,8 +178,6 @@ object Z80Comparisons { } calculateFlags :+ jump } else if (compType == Equal || compType == NotEqual) { - import ZRegister._ - import ZOpcode._ val calculateFlags = calculated ++ List( ZLine.ld8(A, L), ZLine.register(XOR, if (useBC) C else E), @@ -169,8 +193,6 @@ object Z80Comparisons { } calculateFlags :+ jump } else { - import ZRegister._ - import ZOpcode._ val calculateFlags = calculated ++ List( ZLine.ld8(A, L), ZLine.register(SUB, if (useBC) C else E), diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index 9a836c24..ba8edcdf 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -420,4 +420,44 @@ class ComparisonSuite extends FunSuite with Matchers { m.readByte(0xc002) should equal(0) } } + + test("Compare to $ffff") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | byte output @$c000 + | void main() { + | stuff($ffff) + | barrier() + | } + | noinline void stuff (word x) { + | if x == $ffff { + | output = 11 + | } + | } + | noinline void barrier() {} + """.stripMargin + ) { m => + m.readByte(0xc000) should equal(11) + } + } + + test("Compare to 0") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | byte output @$c000 + | void main() { + | stuff(0) + | barrier() + | } + | noinline void stuff (word x) { + | if x == 0 { + | output = 11 + | } + | } + | noinline void barrier() {} + """.stripMargin + ) { m => + m.readByte(0xc000) should equal(11) + } + } }