diff --git a/src/main/scala/millfork/OptimizationPresets.scala b/src/main/scala/millfork/OptimizationPresets.scala index 6983429d..3a7ba776 100644 --- a/src/main/scala/millfork/OptimizationPresets.scala +++ b/src/main/scala/millfork/OptimizationPresets.scala @@ -113,10 +113,12 @@ object OptimizationPresets { val Good: List[AssemblyOptimization] = List[AssemblyOptimization]( AlwaysGoodOptimizations.Adc0Optimization, + AlwaysGoodOptimizations.BitPackingUnpacking, + AlwaysGoodOptimizations.BranchInPlaceRemoval, AlwaysGoodOptimizations.CarryFlagConversion, DangerousOptimizations.ConstantIndexOffsetPropagation, - AlwaysGoodOptimizations.BranchInPlaceRemoval, AlwaysGoodOptimizations.CommonBranchBodyOptimization, + AlwaysGoodOptimizations.CommonExpressionInConditional, AlwaysGoodOptimizations.ConstantFlowAnalysis, AlwaysGoodOptimizations.ConstantIndexPropagation, EmptyMemoryStoreRemoval, diff --git a/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala index 34095b56..c89124a5 100644 --- a/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala +++ b/src/main/scala/millfork/assembly/opt/AlwaysGoodOptimizations.scala @@ -1,12 +1,11 @@ package millfork.assembly.opt -import java.util.UUID import java.util.concurrent.atomic.AtomicInteger -import millfork.assembly.{opt, _} -import millfork.assembly.Opcode._ import millfork.assembly.AddrMode._ +import millfork.assembly.Opcode._ import millfork.assembly.OpcodeClasses._ +import millfork.assembly._ import millfork.env._ /** @@ -823,15 +822,15 @@ object AlwaysGoodOptimizations { (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil), (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ - Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INY))), 
+ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0) + 1)) ~~> (_ => List(AssemblyLine.implied(INY))), (Elidable & HasOpcode(LDY) & MatchImmediate(1) & MatchY(0)) ~ - Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEY))), + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0) - 1)) ~~> (_ => List(AssemblyLine.implied(DEY))), (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0))) ~~> (_ => Nil), (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ - Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)+1)) ~~> (_ => List(AssemblyLine.implied(INX))), + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0) + 1)) ~~> (_ => List(AssemblyLine.implied(INX))), (Elidable & HasOpcode(LDX) & MatchImmediate(1) & MatchX(0)) ~ - Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0)-1)) ~~> (_ => List(AssemblyLine.implied(DEX))), + Where(ctx => ctx.get[Constant](1).quickSimplify.isLowestByteAlwaysEqual(ctx.get[Int](0) - 1)) ~~> (_ => List(AssemblyLine.implied(DEX))), ) val CommonBranchBodyOptimization = new RuleBasedAssemblyOptimization("Common branch body optimization", @@ -858,31 +857,120 @@ object AlwaysGoodOptimizations { (Elidable & HasOpcode(ADC) & HasClear(State.C) & HasClear(State.D) & MatchImmediate(1)) ~ HasOpcode(ASL).+.capture(2) ~ (Elidable & HasOpcode(CLC)) ~ - (Elidable & HasOpcode(ADC) & HasClear(State.D) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> {(code, ctx) => + (Elidable & HasOpcode(ADC) & HasClear(State.D) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> { (code, ctx) => val shifts = ctx.get[List[AssemblyLine]](2) val const = 
ctx.get[Constant](1).asl(shifts.length) + ctx.get[Constant](3) shifts ++ List(AssemblyLine.implied(CLC), AssemblyLine.immediate(ADC, const)) }, - (Elidable & HasOpcode(AND) & MatchImmediate(1)) ~ + (Elidable & HasOpcode(AND) & MatchImmediate(1)) ~ HasOpcode(ASL).+.capture(2) ~ - (Elidable & HasOpcode(AND) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> {(code, ctx) => + (Elidable & HasOpcode(AND) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> { (code, ctx) => val shifts = ctx.get[List[AssemblyLine]](2) val const = CompoundConstant(MathOperator.And, ctx.get[Constant](1).asl(shifts.length), ctx.get[Constant](3)).quickSimplify shifts :+ AssemblyLine.immediate(AND, const) }, - (Elidable & HasOpcode(EOR) & MatchImmediate(1)) ~ + (Elidable & HasOpcode(EOR) & MatchImmediate(1)) ~ HasOpcode(ASL).+.capture(2) ~ - (Elidable & HasOpcode(EOR) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> {(code, ctx) => + (Elidable & HasOpcode(EOR) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> { (code, ctx) => val shifts = ctx.get[List[AssemblyLine]](2) val const = CompoundConstant(MathOperator.Exor, ctx.get[Constant](1).asl(shifts.length), ctx.get[Constant](3)).quickSimplify shifts :+ AssemblyLine.immediate(EOR, const) }, - (Elidable & HasOpcode(ORA) & MatchImmediate(1)) ~ + (Elidable & HasOpcode(ORA) & MatchImmediate(1)) ~ HasOpcode(ASL).+.capture(2) ~ - (Elidable & HasOpcode(ORA) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> {(code, ctx) => + (Elidable & HasOpcode(ORA) & MatchImmediate(3) & DoesntMatterWhatItDoesWith(State.C, State.Z, State.N)) ~~> { (code, ctx) => val shifts = ctx.get[List[AssemblyLine]](2) val const = CompoundConstant(MathOperator.Or, ctx.get[Constant](1).asl(shifts.length), ctx.get[Constant](3)).quickSimplify shifts :+ AssemblyLine.immediate(ORA, const) }, ) + + val BitPackingUnpacking = new 
RuleBasedAssemblyOptimization("Bit packing/unpacking", + needsFlowInfo = FlowInfoRequirement.BothFlows, + (Elidable & HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(AND) & HasImmediate(1)) ~ + ((Elidable & Linear & Not(ChangesMemory) & DoesNotConcernMemoryAt(0, 1) & Not(ChangesA)).* ~ + (Elidable & HasOpcode(STA) & DoesNotConcernMemoryAt(0, 1))).capture(3) ~ + ((Elidable & HasOpcodeIn(Set(LSR, ROR)) & Not(ChangesA) & MatchAddrMode(0) & Not(MatchParameter(1))).* ~ + (Elidable & HasOpcodeIn(Set(LSR, ROR)) & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.Z, State.C, State.N))).capture(2) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) ++ + List(AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL)) ++ + ctx.get[List[AssemblyLine]](3) + }, + (Elidable & HasOpcode(LDA) & HasImmediate(1)) ~ + (Elidable & HasOpcode(AND) & MatchAddrMode(0) & MatchParameter(1)) ~ + ((Elidable & Linear & Not(ChangesMemory) & DoesNotConcernMemoryAt(0, 1) & Not(ChangesA)).* ~ + (Elidable & HasOpcode(STA) & DoesNotConcernMemoryAt(0, 1))).capture(3) ~ + ((Elidable & HasOpcodeIn(Set(LSR, ROR)) & Not(ChangesA) & MatchAddrMode(0) & Not(MatchParameter(1))).* ~ + (Elidable & HasOpcodeIn(Set(LSR, ROR)) & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.Z, State.C, State.N))).capture(2) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](2) ++ + List(AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(ROL)) ++ + ctx.get[List[AssemblyLine]](3) + }, + (Elidable & (HasOpcode(ASL) | HasOpcode(ROL) & HasClear(State.C)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(ROL) & Not(ChangesA) & MatchAddrMode(0) & Not(MatchParameter(1))).*.capture(2) ~ + (Elidable & HasOpcode(CLC)).? ~ + (Elidable & HasOpcodeIn(Set(LDA, TYA, TXA, PLA))).capture(3) ~ + (Elidable & HasOpcode(AND) & HasImmediate(1)) ~ + (Elidable & HasOpcode(CLC)).? 
~ + (Elidable & (HasOpcode(ORA) | HasOpcode(ADC) & HasClear(State.C) & HasClear(State.D)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(STA) & MatchAddrMode(0) & MatchParameter(1) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Z, State.V, State.A)) ~~> { (code, ctx) => + ctx.get[List[AssemblyLine]](3) ++ + List(AssemblyLine.implied(ROR), code.head.copy(opcode = ROL)) ++ + ctx.get[List[AssemblyLine]](2) + }, + ) + + private def blockIsIdempotentWhenItComesToIndexRegisters(i: Int) = Where(ctx => { + val code = ctx.get[List[AssemblyLine]](i) + val rx = code.indexWhere(ReadsX) + val wx = code.indexWhere(l => ChangesX(l.opcode)) + val ry = code.indexWhere(ReadsY) + val wy = code.indexWhere(l => ChangesY(l.opcode)) + val xOk = rx < 0 || wx < 0 || rx >= wx + val yOk = ry < 0 || wy < 0 || ry >= wy + xOk && yOk + }) + + val CommonExpressionInConditional = new RuleBasedAssemblyOptimization("Common expression in conditional", + needsFlowInfo = FlowInfoRequirement.BackwardFlow, + ( + (HasOpcodeIn(Set(LDA, LAX)) & MatchAddrMode(0) & MatchParameter(1)) ~ + HasOpcodeIn(Set(LDY, LDX, AND, ORA, EOR, ADC, SBC, CLC, SEC, CPY, CPX, CMP)).* + ).capture(7) ~ + blockIsIdempotentWhenItComesToIndexRegisters(7) ~ + HasOpcodeIn(ShortConditionalBranching) ~ + MatchElidableCopyOf(7, Anything, DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V)) ~~> { code => + code.take(code.length / 2 + 1) + }, + + ( + ( + (HasOpcodeIn(Set(LDA, LAX)) & MatchAddrMode(0) & MatchParameter(1)) ~ + HasOpcodeIn(Set(LDY, LDX, AND, ORA, EOR, ADC, SBC, CLC, SEC, CPY, CPX, CMP)).* + ).capture(7) ~ + blockIsIdempotentWhenItComesToIndexRegisters(7) ~ + (HasOpcodeIn(ShortConditionalBranching) & MatchParameter(2)) ~ + Not(HasOpcode(LABEL) & MatchParameter(2)).* ~ + (HasOpcode(LABEL) & MatchParameter(2)) + ).capture(3) ~ + MatchElidableCopyOf(7, Anything, DoesntMatterWhatItDoesWith(State.C, State.Z, State.N, State.V)) ~~> { (_, ctx) => + ctx.get[List[AssemblyLine]](3) + }, + + (Elidable 
& HasOpcodeIn(Set(LDA, LAX)) & MatchAddrMode(0) & MatchParameter(1)) ~ + (Elidable & HasOpcode(AND) & HasAddrModeIn(Set(Absolute, ZeroPage)) & DoesntMatterWhatItDoesWith(State.C, State.V, State.A)) ~ + HasOpcodeIn(Set(BEQ, BNE)) ~ + (HasOpcodeIn(Set(LDA, LAX)) & MatchAddrMode(0) & MatchParameter(1)) ~~> { code => + List(code(0), code(1).copy(opcode = BIT), code(2)) + }, + + (Elidable & HasOpcode(LDA) & HasAddrModeIn(Set(Absolute, ZeroPage))) ~ + (Elidable & HasOpcode(AND) & MatchAddrMode(0) & MatchParameter(1)) ~ + HasOpcodeIn(Set(BEQ, BNE)) ~ + (HasOpcode(LDA) & MatchAddrMode(0) & MatchParameter(1)) ~~> { code => + List(code(1).copy(opcode = LDA), code(0).copy(opcode = BIT), code(2)) + }, + + ) } diff --git a/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala index 97878ef1..bc4e372c 100644 --- a/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala +++ b/src/main/scala/millfork/assembly/opt/RuleBasedAssemblyOptimization.scala @@ -383,16 +383,16 @@ trait AssemblyLinePattern extends AssemblyPattern { handleKnownDistance((-distance).toShort) case (CompoundConstant(MathOperator.Minus, a, NumericConstant(distance, _)), b) if a.quickSimplify == b.quickSimplify => handleKnownDistance(distance.toShort) - case (MemoryAddressConstant(a: ThingInMemory), MemoryAddressConstant(b:ThingInMemory)) => + case (MemoryAddressConstant(a: ThingInMemory), MemoryAddressConstant(b: ThingInMemory)) => a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)), - MemoryAddressConstant(b: ThingInMemory)) => + MemoryAddressConstant(b: ThingInMemory)) => a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? 
case (MemoryAddressConstant(a: ThingInMemory), - CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => + CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? case (CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(a: ThingInMemory), NumericConstant(_, _)), - CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => + CompoundConstant(MathOperator.Plus | MathOperator.Minus, MemoryAddressConstant(b: ThingInMemory), NumericConstant(_, _))) => a.name.takeWhile(_ != '.') != b.name.takeWhile(_ != '.') // TODO: ??? case _ => false @@ -400,6 +400,10 @@ trait AssemblyLinePattern extends AssemblyPattern { } } +trait TrivialAssemblyLinePattern extends AssemblyLinePattern with (AssemblyLine => Boolean) { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = this (line) +} + //noinspection LanguageFeature object AssemblyLinePattern { implicit def __implicitOpcodeIn(ops: Set[Opcode.Value]): AssemblyLinePattern = HasOpcodeIn(ops) @@ -488,9 +492,8 @@ case class HasClear(state: State.Value) extends AssemblyLinePattern { flowInfo.hasClear(state) } -case object Anything extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = - true +case object Anything extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = true } case class Not(inner: AssemblyLinePattern) extends AssemblyLinePattern { @@ -536,23 +539,23 @@ case object Linear extends AssemblyLinePattern { OpcodeClasses.AllLinear(line.opcode) } -case object LinearOrBranch extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, 
flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object LinearOrBranch extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.AllLinear(line.opcode) || OpcodeClasses.ShortBranching(line.opcode) } -case object LinearOrLabel extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object LinearOrLabel extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = line.opcode == Opcode.LABEL || OpcodeClasses.AllLinear(line.opcode) } -case object ReadsA extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ReadsA extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ReadsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ReadsAIfImplied(line.opcode) } -case object ReadsMemory extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ReadsMemory extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = line.addrMode match { case AddrMode.Indirect => true case AddrMode.Implied | AddrMode.Immediate => false @@ -561,51 +564,51 @@ case object ReadsMemory extends AssemblyLinePattern { } } -case object ReadsX extends AssemblyLinePattern { +case object ReadsX extends TrivialAssemblyLinePattern { val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX, AddrMode.AbsoluteIndexedX) - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ReadsXAlways(line.opcode) || XAddrModes(line.addrMode) } -case object ReadsY extends AssemblyLinePattern { +case object ReadsY extends 
TrivialAssemblyLinePattern { val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY) - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ReadsYAlways(line.opcode) || YAddrModes(line.addrMode) } -case object ConcernsC extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ConcernsC extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ReadsC(line.opcode) && OpcodeClasses.ChangesC(line.opcode) } -case object ConcernsA extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ConcernsA extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ConcernsAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ConcernsAIfImplied(line.opcode) } -case object ConcernsX extends AssemblyLinePattern { +case object ConcernsX extends TrivialAssemblyLinePattern { val XAddrModes = Set(AddrMode.AbsoluteX, AddrMode.IndexedX, AddrMode.ZeroPageX) - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ConcernsXAlways(line.opcode) || XAddrModes(line.addrMode) } -case object ConcernsY extends AssemblyLinePattern { +case object ConcernsY extends TrivialAssemblyLinePattern { val YAddrModes = Set(AddrMode.AbsoluteY, AddrMode.IndexedY, AddrMode.ZeroPageY) - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ConcernsYAlways(line.opcode) || YAddrModes(line.addrMode) } -case object ChangesA extends AssemblyLinePattern { - 
override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ChangesA extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ChangesAAlways(line.opcode) || line.addrMode == AddrMode.Implied && OpcodeClasses.ChangesAIfImplied(line.opcode) } -case object ChangesMemory extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case object ChangesMemory extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = OpcodeClasses.ChangesMemoryAlways(line.opcode) || line.addrMode != AddrMode.Implied && OpcodeClasses.ChangesMemoryIfNotImplied(line.opcode) } @@ -620,9 +623,9 @@ case class DoesntChangeMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLin } } -case object ConcernsMemory extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = - ReadsMemory.matchLineTo(ctx, flowInfo, line) && ChangesMemory.matchLineTo(ctx, flowInfo, line) +case object ConcernsMemory extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = + ReadsMemory(line) && ChangesMemory(line) } case class DoesNotConcernMemoryAt(addrMode1: Int, param1: Int) extends AssemblyLinePattern { @@ -635,36 +638,36 @@ case class DoesNotConcernMemoryAt(addrMode1: Int, param1: Int) extends AssemblyL } } -case class HasOpcode(op: Opcode.Value) extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case class HasOpcode(op: Opcode.Value) extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = line.opcode == op override def toString: String = op.toString } -case class HasOpcodeIn(ops: Set[Opcode.Value]) extends AssemblyLinePattern { - override def matchLineTo(ctx: 
AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case class HasOpcodeIn(ops: Set[Opcode.Value]) extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = ops(line.opcode) override def toString: String = ops.mkString("{", ",", "}") } -case class HasAddrMode(am: AddrMode.Value) extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case class HasAddrMode(am: AddrMode.Value) extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = line.addrMode == am override def toString: String = am.toString } -case class HasAddrModeIn(ams: Set[AddrMode.Value]) extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case class HasAddrModeIn(ams: Set[AddrMode.Value]) extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = ams(line.addrMode) override def toString: String = ams.mkString("{", ",", "}") } -case class HasImmediate(i: Int) extends AssemblyLinePattern { - override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: AssemblyLine): Boolean = +case class HasImmediate(i: Int) extends TrivialAssemblyLinePattern { + override def apply(line: AssemblyLine): Boolean = line.addrMode == AddrMode.Immediate && (line.parameter.quickSimplify match { case NumericConstant(j, _) => (i & 0xff) == (j & 0xff) case _ => false @@ -751,4 +754,22 @@ case class HasCallerCount(count: Int) extends AssemblyLinePattern { case AssemblyLine(Opcode.LABEL, _, MemoryAddressConstant(Label(l)), _) => flowInfo.labelUseCount(l) == count case _ => false } +} + +case class MatchElidableCopyOf(i: Int, firstLinePattern: AssemblyLinePattern, lastLinePattern: AssemblyLinePattern) extends AssemblyPattern { /* Matches an elidable, line-for-line copy (same opcode, addressing mode and simplified parameter) of the block captured in the context under index i; the first and last copied lines must additionally satisfy firstLinePattern and lastLinePattern. Yields the code remaining after the copy. */ + override def matchTo(ctx: AssemblyMatchingContext, code: List[(FlowInfo, AssemblyLine)]): 
Option[List[(FlowInfo, AssemblyLine)]] = { + val pattern = ctx.get[List[AssemblyLine]](i) + if (code.length < pattern.length) return None + val (before, after) = code.splitAt(pattern.length) + val lastIndex = pattern.length - 1 /* index of the final line of the copy; must be derived from the captured block, not from the whole remaining code, otherwise lastLinePattern is only checked when the copy ends the code */ + for (((a, (f, b)), ix) <- pattern.zip(before).zipWithIndex) { + if (!b.elidable) return None + if (a.opcode != b.opcode) return None + if (a.addrMode != b.addrMode) return None + if (a.parameter.quickSimplify != b.parameter.quickSimplify) return None + if (ix == 0 && !firstLinePattern.matchLineTo(ctx, f, b)) return None + if (ix == lastIndex && !lastLinePattern.matchLineTo(ctx, f, b)) return None + } + Some(after) + } } \ No newline at end of file diff --git a/src/test/scala/millfork/test/BitPackingSuite.scala b/src/test/scala/millfork/test/BitPackingSuite.scala new file mode 100644 index 00000000..ea08f0e4 --- /dev/null +++ b/src/test/scala/millfork/test/BitPackingSuite.scala @@ -0,0 +1,129 @@ +package millfork.test + +import millfork.error.ErrorReporting +import millfork.test.emu.EmuBenchmarkRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class BitPackingSuite extends FunSuite with Matchers { + + test("Unpack bits from a byte") { + EmuBenchmarkRun(""" + | array output[8] + | word output_addr @$c000 + | void main () { + | byte b + | output_addr = output.addr + | b = $56 + | barrier() + | byte i + | for i,0,until,8 { + | output[i] = b & 1 + | b >>= 1 + | } + | } + | void barrier() {} + """.stripMargin){m => + val addr = m.readWord(0xc000) + m.readByte(addr) should equal(0) + m.readByte(addr + 1) should equal(1) + m.readByte(addr + 2) should equal(1) + m.readByte(addr + 3) should equal(0) + m.readByte(addr + 4) should equal(1) + m.readByte(addr + 5) should equal(0) + m.readByte(addr + 6) should equal(1) + m.readByte(addr + 7) should equal(0) + } + } + + test("Unpack bits from a word") { + EmuBenchmarkRun(""" + | array output[16] + | word output_addr @$c000 + | void main () { + | word w + | output_addr = 
output.addr + | w = $CC56 + | barrier() + | byte i + | for i,0,until,16 { + | output[i] = w.lo & 1 + | w >>= 1 + | } + | } + | void barrier() {} + """.stripMargin){m => + val addr = m.readWord(0xc000) + m.readByte(addr) should equal(0) + m.readByte(addr + 1) should equal(1) + m.readByte(addr + 2) should equal(1) + m.readByte(addr + 3) should equal(0) + m.readByte(addr + 4) should equal(1) + m.readByte(addr + 5) should equal(0) + m.readByte(addr + 6) should equal(1) + m.readByte(addr + 7) should equal(0) + m.readByte(addr + 8) should equal(0) + m.readByte(addr + 9) should equal(0) + m.readByte(addr + 10) should equal(1) + m.readByte(addr + 11) should equal(1) + m.readByte(addr + 12) should equal(0) + m.readByte(addr + 13) should equal(0) + m.readByte(addr + 14) should equal(1) + m.readByte(addr + 15) should equal(1) + } + } + + test("Pack bits into byte") { + EmuBenchmarkRun(""" + | byte output @$C000 + | array input = [$F0, 1, 0, $41, $10, 1, $61, 0] + | void main () { + | byte i + | output = 0 + | for i,0,until,8 { + | output <<= 1 + | output |= input[i] & 1 + | } + | } + """.stripMargin){m => + m.readByte(0xc000) should equal(0x56) + } + } + + test("Pack bits into word") { + EmuBenchmarkRun(""" + | word output @$C000 + | array input = [$F0, 1, 0, $41, $10, 1, $61, 0, + | 1, 1, 0, 0, 0, 0, 1, 1] + | void main () { + | byte i + | output = 0 + | for i,0,until,16 { + | output <<= 1 + | output |= input[i] & 1 + | } + | } + """.stripMargin){m => + m.readWord(0xc000) should equal(0x56C3) + } + } + + test("Pack bits into byte using plus") { + EmuBenchmarkRun(""" + | byte output @$C000 + | array input = [$F0, 1, 0, $41, $10, 1, $61, 0] + | void main () { + | byte i + | output = 0 + | for i,0,until,8 { + | output <<= 1 + | output += (input[i] & 1) + | } + | } + """.stripMargin){m => + m.readByte(0xc000) should equal(0x56) + } + } +}