From f7dd78e4c0021da2116d0fea87516e5b5c17847c Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Tue, 1 Oct 2019 00:45:14 +0200 Subject: [PATCH] Byte comparison optimizations --- .../millfork/compiler/mos/BuiltIns.scala | 15 +++++++ .../compiler/z80/Z80Comparisons.scala | 32 +++++++++++++- .../scala/millfork/test/ComparisonSuite.scala | 43 ++++++++++++++++++- 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 87e8dc82..09fb57ee 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -446,6 +446,21 @@ object BuiltIns { } case _ => } + (maybeConstant, compType) match { + case (Some(NumericConstant(0, _)), ComparisonType.GreaterUnsigned) => + return compileByteComparison(ctx, ComparisonType.NotEqual, lhs, rhs, branches) + case (Some(NumericConstant(0, _)), ComparisonType.LessOrEqualUnsigned) => + return compileByteComparison(ctx, ComparisonType.Equal, lhs, rhs, branches) + case (Some(NumericConstant(1, _)), ComparisonType.LessUnsigned) => + return compileByteComparison(ctx, ComparisonType.Equal, lhs, rhs #-# 1, branches) + case (Some(NumericConstant(1, _)), ComparisonType.GreaterOrEqualUnsigned) => + return compileByteComparison(ctx, ComparisonType.NotEqual, lhs, rhs #-# 1, branches) + case (Some(NumericConstant(n, 1)), ComparisonType.GreaterUnsigned) if n >= 1 && n <= 254 => + return compileByteComparison(ctx, ComparisonType.GreaterOrEqualUnsigned, lhs, rhs #+# 1, branches) + case (Some(NumericConstant(n, 1)), ComparisonType.LessOrEqualUnsigned) if n >= 1 && n <= 254 => + return compileByteComparison(ctx, ComparisonType.LessUnsigned, lhs, rhs #+# 1, branches) + case _ => + } val cmpOp = if (ComparisonType.isSigned(compType)) SBC else CMP var comparingAgainstZero = false val secondParamCompiled0 = maybeConstant match { diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala index 0bd04cfb..f55d4548 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -21,19 +21,36 @@ object Z80Comparisons { FunctionCallExpression("^", List(r, LiteralExpression(0x80, 1).pos(r.position))).pos(r.position), branches) } + (ctx.env.eval(r), compType) match { + case (Some(NumericConstant(0, _)), GreaterUnsigned) => + return compile8BitComparison(ctx, ComparisonType.NotEqual, l, r, branches) + case (Some(NumericConstant(0, _)), LessOrEqualUnsigned) => + return compile8BitComparison(ctx, ComparisonType.Equal, l, r, branches) + case (Some(NumericConstant(1, _)), ComparisonType.LessUnsigned) => + return compile8BitComparison(ctx, ComparisonType.Equal, l, r #-# 1, branches) + case (Some(NumericConstant(1, _)), ComparisonType.GreaterOrEqualUnsigned) => + return compile8BitComparison(ctx, ComparisonType.NotEqual, l, r #-# 1, branches) + case (Some(NumericConstant(n, 1)), ComparisonType.GreaterUnsigned) if n >= 1 && n <= 254 => + return compile8BitComparison(ctx, ComparisonType.GreaterOrEqualUnsigned, l, r #+# 1, branches) + case (Some(NumericConstant(n, 1)), ComparisonType.LessOrEqualUnsigned) if n >= 1 && n <= 254 => + return compile8BitComparison(ctx, ComparisonType.LessUnsigned, l, r #+# 1, branches) + case _ => + } compType match { case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => return compile8BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) case _ => () } - val prepareAE = Z80ExpressionCompiler.compileToA(ctx, r) match { + + var prepareAE = Z80ExpressionCompiler.compileToA(ctx, r) match { case List(ZLine0(ZOpcode.LD, TwoRegisters(ZRegister.A, ZRegister.IMM_8), param)) => Z80ExpressionCompiler.compileToA(ctx, l) :+ ZLine.ldImm8(ZRegister.E, param) case compiledR => compiledR ++ List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ Z80ExpressionCompiler.stashDEIfChanged(ctx, Z80ExpressionCompiler.compileToA(ctx, l)) } - val calculateFlags = if (ComparisonType.isSigned(compType) && ctx.options.flag(CompilationFlag.EmitZ80Opcodes)) { + + var calculateFlags = if (ComparisonType.isSigned(compType) && ctx.options.flag(CompilationFlag.EmitZ80Opcodes)) { val fixup = ctx.nextLabel("co") List( ZLine.register(ZOpcode.SUB, ZRegister.E), @@ -43,6 +60,17 @@ object Z80Comparisons { } else if (ComparisonType.isSigned(compType) && !ctx.options.flag(CompilationFlag.EmitZ80Opcodes)) { List(ZLine.register(ZOpcode.SUB, ZRegister.E)) } else List(ZLine.register(ZOpcode.CP, ZRegister.E)) + + (prepareAE.last, calculateFlags.head) match { + case ( + ZLine0(ZOpcode.LD, TwoRegisters(ZRegister.E, ZRegister.IMM_8), c), + ZLine0(op, OneRegister(ZRegister.E), _) + ) => + prepareAE = prepareAE.init + calculateFlags = ZLine.imm8(op, c) :: calculateFlags.tail + case _ => + } + if (branches == NoBranching) return prepareAE ++ calculateFlags val (effectiveCompType, label) = branches match { case BranchIfFalse(la) => ComparisonType.negate(compType) -> la diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index fb8a7a2d..12a46f1b 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuSuperOptimizedRun, EmuUltraBenchmarkRun} +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuSuperOptimizedRun, EmuUltraBenchmarkRun, EmuUnoptimizedCrossPlatformRun} import org.scalatest.{FunSuite, Matchers} /** @@ -540,4 +540,45 @@ class ComparisonSuite extends FunSuite with Matchers { |""".stripMargin EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Sixteen, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(code) {m => m.readByte(0xc000) should equal (code.count(_ == '↑'))} } + + test("Comparison optimization") { + val code = + """ + | noinline byte one() = 1 + | noinline byte ten() = 10 + | noinline byte ff() = $ff + | noinline byte zero() = 0 + | byte output @$c000 + | void main() { + | output = 0 + | if zero() >= 0 { output += 1 } // ↑ + | if zero() <= 0 { output += 1 } // ↑ + | if one() > 0 { output += 1 } // ↑ + | if one() >= 0 { output += 1 } // ↑ + | if ten() > 0 { output += 1 } // ↑ + | + | if ten() <= $ff { output += 1 } // ↑ + | if ten() < $ff { output += 1 } // ↑ + | if ff() >= $ff { output += 1 } // ↑ + | + | if one() >= 1 { output += 1 } // ↑ + | if one() <= 1 { output += 1 } // ↑ + | if ten() >= 1 { output += 1 } // ↑ + | if ten() > 1 { output += 1 } // ↑ + | if zero() < 1 { output += 1 } // ↑ + | if zero() <= 1 { output += 1 } // ↑ + | + | if one() < 10 { output += 1 } // ↑ + | if one() <= 10 { output += 1 } // ↑ + | if ten() <= 10 { output += 1 } // ↑ + | if ten() >= 10 { output += 1 } // ↑ + | if ff() > 10 { output += 1 } // ↑ + | if ff() >= 10 { output += 1 } // ↑ + | } + |""".stripMargin + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(code) { m => + m.readByte(0xc000) should equal(code.count(_ == '↑')) + } + } + }