From 869211658affc3240688bf306bfdecdb8f26aab3 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 24 Jun 2019 00:13:16 +0200 Subject: [PATCH] Division optimizations --- .../opt/RuleBasedAssemblyOptimization.scala | 29 ++++++++ .../opt/ZeropageRegisterOptimizations.scala | 72 +++++++++++++++++++ .../z80/opt/AlwaysGoodI80Optimizations.scala | 47 ++++++++++++ .../scala/millfork/test/ByteMathSuite.scala | 44 ++++++++++++ .../scala/millfork/test/WordMathSuite.scala | 49 +++++++++++++ 5 files changed, 241 insertions(+) diff --git a/src/main/scala/millfork/assembly/mos/opt/RuleBasedAssemblyOptimization.scala b/src/main/scala/millfork/assembly/mos/opt/RuleBasedAssemblyOptimization.scala index 04130854..acf8e7d3 100644 --- a/src/main/scala/millfork/assembly/mos/opt/RuleBasedAssemblyOptimization.scala +++ b/src/main/scala/millfork/assembly/mos/opt/RuleBasedAssemblyOptimization.scala @@ -552,6 +552,35 @@ case class MatchA(i: Int) extends AssemblyLinePattern { override def hitRate: Double = 0.42 } +case class MatchStoredRegister(i: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = { + import Opcode._ + line.opcode match { + case STA => + flowInfo.statusBefore.a match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } + case STX => + flowInfo.statusBefore.x match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } + case STY => + flowInfo.statusBefore.y match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } + case _ => false + } + } + + override def hitRate: Double = 0.42 +} + case class MatchX(i: Int) extends AssemblyLinePattern { override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = FlowInfoRequirement.assertForward(needsFlowInfo) diff --git a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala index fc1d6bb6..224632c0 100644 --- a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala +++ b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala @@ -5,6 +5,7 @@ import millfork.assembly.mos.AddrMode._ import millfork.assembly.AssemblyOptimization import millfork.assembly.mos.{AssemblyLine, AssemblyLine0, Opcode, State} import millfork.DecimalUtils.asDecimal +import millfork.assembly.z80.opt.HasRegister import millfork.error.FatalErrorReporting /** * @author Karol Stasiak @@ -115,6 +116,76 @@ object ZeropageRegisterOptimizations { }, ) + val ConstantDivision = new RuleBasedAssemblyOptimization("Constant division", + needsFlowInfo = FlowInfoRequirement.BothFlows, + // TODO: constants other than power of 2: + + (HasOpcodeIn(STA, STX, STY) & RefersTo("__reg", 1) & MatchStoredRegister(2) & MatchAddrMode(0) & MatchParameter(1)) ~ + Where({ ctx => + val a = ctx.get[Int](2) + a != 0 && a.-(1).&(a) == 0 + }) ~ + (Linear & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(JSR) & RefersTo("__div_u8u8u8u8", 0) + & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V) // everything else (including Y) should be preserved + & DoesntMatterWhatItDoesWithReg(0) + & DoesntMatterWhatItDoesWithReg(1)) ~~> { (code, ctx) => + val count = Integer.numberOfTrailingZeros(ctx.get[Int](2)) + val zreg = ctx.zeropageRegister.get + code.init ++ List(AssemblyLine.zeropage(LDA, zreg)) ++ List.fill(count)(AssemblyLine.implied(LSR)) + }, + + (HasOpcodeIn(STA, STX, STY) & RefersTo("__reg", 1) & MatchStoredRegister(2) & MatchAddrMode(0) & MatchParameter(1)) ~ + Where({ ctx => + val a = ctx.get[Int](2) + a != 0 && a.-(1).&(a) == 0 + }) ~ + (Linear & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(JSR) & RefersTo("__mod_u8u8u8u8", 0) + & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V, State.X) // everything else (including Y) should be preserved + & DoesntMatterWhatItDoesWithReg(0) + & DoesntMatterWhatItDoesWithReg(1)) ~~> { (code, ctx) => + val a = ctx.get[Int](2) + val zreg = ctx.zeropageRegister.get + code.init ++ List(AssemblyLine.zeropage(LDA, zreg), AssemblyLine.immediate(AND, a - 1)) + }, + + (HasOpcodeIn(STA, STX, STY) & RefersTo("__reg", 2) & MatchStoredRegister(2) & MatchAddrMode(0) & MatchParameter(1)) ~ + Where({ ctx => + val a = ctx.get[Int](2) + a != 0 && a.-(1).&(a) == 0 + }) ~ + (Linear & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(JSR) & RefersTo("__div_u16u8u16u8", 0) + & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V) // everything else (including Y) should be preserved + & DoesntMatterWhatItDoesWithReg(0) + & DoesntMatterWhatItDoesWithReg(1) + & DoesntMatterWhatItDoesWithReg(2)) ~~> { (code, ctx) => + val count = Integer.numberOfTrailingZeros(ctx.get[Int](2)) + val zreg = ctx.zeropageRegister.get + code.init ++ + List.fill(count)(List(AssemblyLine.zeropage(LSR, zreg, 1), AssemblyLine.zeropage(ROR, zreg))).flatten ++ + List(AssemblyLine.zeropage(LDA, zreg), AssemblyLine.zeropage(LDX, zreg, 1)) + }, + + (HasOpcodeIn(STA, STX, STY) & RefersTo("__reg", 2) & MatchStoredRegister(2) & MatchAddrMode(0) & MatchParameter(1)) ~ + Where({ ctx => + val a = ctx.get[Int](2) + a != 0 && a.-(1).&(a) == 0 && a <= 128 + }) ~ + (Linear & DoesntChangeMemoryAt(0, 1)).* ~ + (Elidable & HasOpcode(JSR) & RefersTo("__mod_u16u8u16u8", 0) + & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V, State.X) // everything else (including Y) should be preserved + & DoesntMatterWhatItDoesWithReg(0) + & DoesntMatterWhatItDoesWithReg(1) + & DoesntMatterWhatItDoesWithReg(2)) ~~> { (code, ctx) => + val a = ctx.get[Int](2) + val zreg = ctx.zeropageRegister.get + code.init ++ List(AssemblyLine.zeropage(LDA, zreg), AssemblyLine.immediate(AND, a - 1), AssemblyLine.immediate(LDX, 0)) + }, + + ) + val ConstantDecimalMath = new RuleBasedAssemblyOptimization("Constant decimal math", needsFlowInfo = FlowInfoRequirement.BothFlows, @@ -400,6 +471,7 @@ object ZeropageRegisterOptimizations { val All: List[AssemblyOptimization[AssemblyLine]] = List( ConstantDecimalMath, + ConstantDivision, ConstantMultiplication, ConstantInlinedMultiplication, LoadingKnownValue, diff --git a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala index 7caa0005..cc7cf511 100644 --- a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala +++ b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodI80Optimizations.scala @@ -6,6 +6,7 @@ import millfork.assembly.z80.ZOpcode._ import millfork.env.{CompoundConstant, Constant, InitializedArray, MathOperator, MemoryAddressConstant, NumericConstant} import millfork.node.{LiteralExpression, ZRegister} import ZRegister._ +import millfork.CompilationFlag import millfork.DecimalUtils._ import millfork.error.FatalErrorReporting @@ -1309,6 +1310,51 @@ object AlwaysGoodI80Optimizations { ) + val ConstantDivision = new RuleBasedAssemblyOptimization("Constant division", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(CALL) + & IsUnconditional + & RefersTo("__divmod_u16u8u16u8", 0) + & MatchRegister(ZRegister.H, 4) + & MatchRegister(ZRegister.L, 5) + & MatchRegister(ZRegister.D, 6) + & DoesntMatterWhatItDoesWithFlags + & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C, ZRegister.B)) ~~> { (_, ctx) => + val p = ctx.get[Int](4) * 256 + ctx.get[Int](5).&(0xff) + val q = ctx.get[Int](6) + if (q == 0) Nil // lol undefined behaviour, everyone's favourite C feature + else List(ZLine.ldImm16(ZRegister.HL, p / q), ZLine.ldImm8(ZRegister.A, p % q)) + }, + (Elidable & HasOpcode(CALL) + & IsUnconditional + & RefersTo("__divmod_u16u8u16u8", 0) + & MatchRegister(ZRegister.D, 6) + & DoesntMatterWhatItDoesWithFlags + & DoesntMatterWhatItDoesWith(ZRegister.D, ZRegister.E, ZRegister.C, ZRegister.B)) ~ + Where(ctx => { + val q = ctx.get[Int](6) + q != 0 && q.&(q - 1) == 0 + }) ~~> { (_, ctx) => + val q = ctx.get[Int](6) + if (ctx.compilationOptions.flag(CompilationFlag.EmitExtended80Opcodes)) { + List(ZLine.ld8(ZRegister.A, ZRegister.L), ZLine.imm8(ZOpcode.AND, q - 1)) ++ (0L until Integer.numberOfTrailingZeros(q)).flatMap(_ => List( + ZLine.register(ZOpcode.SRL, ZRegister.H), + ZLine.register(ZOpcode.RR, ZRegister.L) + )) + } else { + List(ZLine.ld8(ZRegister.D, ZRegister.L)) ++ (0L until Integer.numberOfTrailingZeros(q)).flatMap(_ => List( + ZLine.ld8(ZRegister.A, ZRegister.H), + ZLine.register(ZOpcode.OR, ZRegister.A), + ZLine.implied(ZOpcode.RRA), + ZLine.ld8(ZRegister.H, ZRegister.A), + ZLine.ld8(ZRegister.A, ZRegister.L), + ZLine.implied(ZOpcode.RRA), + ZLine.ld8(ZRegister.L, ZRegister.A) + )) ++ List(ZLine.ld8(ZRegister.A, ZRegister.D), ZLine.imm8(ZOpcode.AND, q - 1)) + } + }, + ) + private def compileMultiply[T](multiplicand: Int, add1:List[T], asl: List[T]): List[T] = { if (multiplicand == 0) FatalErrorReporting.reportFlyingPig("Trying to optimize multiplication by 0 in a wrong way!") def impl(m: Int): List[List[T]] = { @@ -1537,6 +1583,7 @@ object AlwaysGoodI80Optimizations { val All: List[AssemblyOptimization[ZLine]] = List[AssemblyOptimization[ZLine]]( BranchInPlaceRemoval, + ConstantDivision, ConstantMultiplication, ConstantInlinedShifting, FreeHL, diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index b38001ee..1d8ead24 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -390,4 +390,48 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues { divisionCase1(42, 128) divisionCase1(142, 128) } + + test("Byte division 4") { + divisionCase4(0, 2) + divisionCase4(1, 2) + divisionCase4(2, 2) + divisionCase4(250, 128) + divisionCase4(0, 4) + divisionCase4(0, 8) + divisionCase4(1, 4) + divisionCase4(6, 8) + divisionCase4(73, 16) + divisionCase4(75, 128) + divisionCase4(42, 128) + divisionCase4(142, 128) + } + + private def divisionCase4(x: Int, y: Int): Unit = { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp, Cpu.Intel8086)( + s""" + | import zp_reg + | byte output_q1 @$$c000 + | byte output_m1 @$$c001 + | byte output_q2 @$$c002 + | byte output_m2 @$$c003 + | void main () { + | byte a + | output_q2 = g() + | output_m2 = g() + | a = f() + | output_q1 = $x / a + | output_m1 = $x %% a + | output_q2 /= a + | output_m2 %%= a + | } + | byte f() {return $y} + | noinline byte g() {return $x} + """. + stripMargin) { m => + m.readByte(0xc000) should equal(x / y) withClue s"$x / $y" + m.readByte(0xc001) should equal(x % y) withClue s"$x %% $y" + m.readByte(0xc002) should equal(x / y) withClue s"$x / $y" + m.readByte(0xc003) should equal(x % y) withClue s"$x %% $y" + } + } } diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index 2a7dfbe4..41ff1470 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -538,4 +538,53 @@ class WordMathSuite extends FunSuite with Matchers with AppendedClues { m.readWord(0xc006) should equal(x % y) withClue s"= $x %% $y" } } + + test("Word division 4") { + divisionCase4(0, 2) + divisionCase4(1, 2) + divisionCase4(2, 2) + divisionCase4(250, 128) + divisionCase4(0, 4) + divisionCase4(0, 8) + divisionCase4(1, 4) + divisionCase4(6, 8) + divisionCase4(73, 16) + divisionCase4(75, 128) + divisionCase4(42, 128) + divisionCase4(142, 128) + divisionCase2(2534, 2) + divisionCase2(2534, 32) + divisionCase2(35000, 2) + divisionCase2(51462, 4) + divisionCase2(51462, 1) + } + + private def divisionCase4(x: Int, y: Int): Unit = { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp, Cpu.Intel8086)( + s""" + | import zp_reg + | word output_q1 @$$c000 + | byte output_m1 @$$c002 + | word output_q2 @$$c004 + | word output_m2 @$$c006 + | void main () { + | byte a + | output_q2 = g() + | output_m2 = g() + | a = f() + | output_q1 = $x / a + | output_m1 = $x %% a + | output_q2 /= a + | output_m2 %%= a + | } + | byte f() {return $y} + | noinline word g() {return $x} + """. + stripMargin) { m => + m.readWord(0xc000) should equal(x / y) withClue s"$x / $y" + m.readByte(0xc002) should equal(x % y) withClue s"$x %% $y" + m.readWord(0xc004) should equal(x / y) withClue s"$x / $y" + m.readByte(0xc006) should equal(x % y) withClue s"$x %% $y" + } + } }