diff --git a/src/main/scala/millfork/assembly/z80/ZLine.scala b/src/main/scala/millfork/assembly/z80/ZLine.scala index ac875688..2acccfc3 100644 --- a/src/main/scala/millfork/assembly/z80/ZLine.scala +++ b/src/main/scala/millfork/assembly/z80/ZLine.scala @@ -89,6 +89,14 @@ object ZLine { def jump(label: Label, condition: ZRegisters): ZLine = ZLine(JP, condition, label.toAddress) + def jumpR(label: String): ZLine = ZLine(JR, NoRegisters, Label(label).toAddress) + + def jumpR(label: Label): ZLine = ZLine(JR, NoRegisters, label.toAddress) + + def jumpR(label: String, condition: ZRegisters): ZLine = ZLine(JR, condition, Label(label).toAddress) + + def jumpR(label: Label, condition: ZRegisters): ZLine = ZLine(JR, condition, label.toAddress) + def djnz(label: String): ZLine = ZLine(DJNZ, NoRegisters, Label(label).toAddress) def djnz(label: Label): ZLine = ZLine(DJNZ, NoRegisters, label.toAddress) @@ -144,6 +152,10 @@ object ZLine { def ldViaIx(target: ZRegister.Value, sourceOffset: Int): ZLine = ZLine(LD, TwoRegistersOffset(target, ZRegister.MEM_IX_D, sourceOffset), Constant.Zero) def ldViaIx(targetOffset: Int, source: ZRegister.Value): ZLine = ZLine(LD, TwoRegistersOffset(ZRegister.MEM_IX_D, source, targetOffset), Constant.Zero) + + def ldViaIy(target: ZRegister.Value, sourceOffset: Int): ZLine = ZLine(LD, TwoRegistersOffset(target, ZRegister.MEM_IY_D, sourceOffset), Constant.Zero) + + def ldViaIy(targetOffset: Int, source: ZRegister.Value): ZLine = ZLine(LD, TwoRegistersOffset(ZRegister.MEM_IY_D, source, targetOffset), Constant.Zero) } case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Constant, elidable: Boolean = true) extends AbstractCode { @@ -228,10 +240,10 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case JP | JR | DJNZ | CALL => val ps = registers match { case NoRegisters => s" $parameter" - case IfFlagSet(ZFlag.P) => " PO,$parameter" - case IfFlagClear(ZFlag.P) => " PE,$parameter" - case IfFlagSet(ZFlag.S) => " M,$parameter" - case IfFlagClear(ZFlag.S) => " P,$parameter" + case IfFlagSet(ZFlag.P) => s" PO,$parameter" + case IfFlagClear(ZFlag.P) => s" PE,$parameter" + case IfFlagSet(ZFlag.S) => s" M,$parameter" + case IfFlagClear(ZFlag.S) => s" P,$parameter" case IfFlagSet(f) => s" $f,$parameter" case IfFlagClear(f) => s" N$f,$parameter" case OneRegister(r) => s" (${asAssemblyString(r)})" @@ -277,6 +289,7 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case TwoRegisters(_, MEM_IY_D) => r == IYH || r == IYL case TwoRegisters(_, MEM_ABS_8 | MEM_ABS_16 | IMM_8 | IMM_16) => false case TwoRegisters(_, s) => r == s + case TwoRegistersOffset(_, s, _) => r == s case _ => false }) || (registers match { case TwoRegisters(MEM_HL, _) => r == H || r == L @@ -294,6 +307,7 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case TwoRegisters(_, MEM_IX_D | IX) => r == IXH || r == IXL case TwoRegisters(_, MEM_IY_D | IY) => r == IYH || r == IYL case TwoRegisters(_, s) => r == s + case TwoRegistersOffset(_, s, _) => r == s case _ => false } case ADD | ADC | OR | XOR | CP | SUB | SBC => registers match { @@ -304,6 +318,7 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case OneRegister(MEM_IY_D) => r == IYH || r == IYL || r == A case OneRegister(IMM_8 | IMM_16) => r == A case OneRegister(s) => r == s || r == A + case OneRegisterOffset(s, _) => r == s case _ => r == A } case INC | DEC | RL | RLC | RR | RRC | SLA | SLL | SRA | SRL => registers match { @@ -313,6 +328,7 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case OneRegister(MEM_IX_D) => r == IXH || r == IXL case OneRegister(MEM_IY_D) => r == IYH || r == IYL case OneRegister(s) => r == s + case OneRegisterOffset(s, _) => r == s case _ => false } case INC_16 | DEC_16 | PUSH => registers match { @@ -322,13 +338,14 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case OneRegister(IX) => r == IXH || r == IXL case OneRegister(IY) => r == IYH || r == IYL case OneRegister(AF) => r == A + case OneRegisterOffset(s, _) => r == s case _ => false } case JP | JR | RET | RETI | RETN | POP | DISCARD_A | DISCARD_BCDEIX | DISCARD_HL | DISCARD_F => false case DJNZ => r == B - case DAA => r == A + case DAA | NEG => r == A case _ => true // TODO } } @@ -348,6 +365,14 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case OneRegisterOffset(s, p) => r == s && o == p case _ => false } + case POP | INC_16 | DEC_16 => registers match { + case OneRegister(IX | IY) => true + case _ => false + } + case LD_16 | ADD_16 => registers match { + case TwoRegisters(IX | IY, _) => true + case _ => false + } case _ => false // TODO } case _ => changesRegister(r) @@ -370,6 +395,7 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta opcode match { case LD => registers match { case TwoRegisters(s, _) => r == s + case TwoRegistersOffset(s, _, _) => r == s case _ => false } case LD_16 | ADD_16 | SBC_16 | ADC_16 => registers match { @@ -379,10 +405,12 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case TwoRegisters(IX, _) => r == IXH || r == IXL case TwoRegisters(IY, _) => r == IYH || r == IYL case TwoRegisters(s, _) => r == s + case TwoRegistersOffset(s, _, _) => r == s case _ => false } case INC | DEC | RL | RLC | RR | RRC | SLA | SLL | SRA | SRL => registers match { case OneRegister(s) => r == s + case OneRegisterOffset(s, _) => r == s case _ => false } case INC_16 | DEC_16 | POP => registers match { @@ -392,12 +420,13 @@ case class ZLine(opcode: ZOpcode.Value, registers: ZRegisters, parameter: Consta case OneRegister(IX) => r == IXH || r == IXL case OneRegister(IY) => r == IYH || r == IYL case OneRegister(AF) => r == A + case OneRegisterOffset(s, _) => r == s case _ => false } case JP | JR | RET | RETI | RETN | POP | DISCARD_A | DISCARD_BCDEIX | DISCARD_HL | DISCARD_F => false - case ADD | ADC | OR | XOR | SUB | SBC | DAA => r == A + case ADD | ADC | OR | XOR | SUB | SBC | DAA | NEG => r == A case CP => false case DJNZ => r == B case _ => true // TODO diff --git a/src/main/scala/millfork/assembly/z80/ZOpcode.scala b/src/main/scala/millfork/assembly/z80/ZOpcode.scala index 5960fa58..49e7d312 100644 --- a/src/main/scala/millfork/assembly/z80/ZOpcode.scala +++ b/src/main/scala/millfork/assembly/z80/ZOpcode.scala @@ -32,6 +32,15 @@ object ZOpcodeClasses { val CbInstructions = Set(SLA, SRA, SRL, SLL, BIT, RES, SET) val CbInstructionsUnlessA = Set(RLC, RRC, RL, RR) + val NoopDiscards = Set(DISCARD_F, DISCARD_A, DISCARD_HL, DISCARD_BCDEIX) + + val ChangesAFAlways = Set( // TODO: ! + DAA, ADD, ADC, SUB, SBC, XOR, OR, AND, INC, DEC, + SCF, CCF, NEG, + ADD_16, ADC_16, SBC_16, INC_16, DEC_16, + INI, INIR, OUTI, OUTIR, IND, INDR, OUTD, OUTDR, + LDI, LDIR, LDD, LDDR, CPI, CPIR, CPD, CPDR, + EXX, CALL, JR, JP, LABEL, DJNZ) val ChangesBCAlways = Set( INI, INIR, OUTI, OUTIR, IND, INDR, OUTD, OUTDR, LDI, LDIR, LDD, LDDR, CPI, CPIR, CPD, CPDR, diff --git a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodZ80Optimizations.scala b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodZ80Optimizations.scala index 0366aaf5..18900748 100644 --- a/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodZ80Optimizations.scala +++ b/src/main/scala/millfork/assembly/z80/opt/AlwaysGoodZ80Optimizations.scala @@ -3,7 +3,7 @@ package millfork.assembly.z80.opt import millfork.assembly.AssemblyOptimization import millfork.assembly.z80._ import millfork.assembly.z80.ZOpcode._ -import millfork.env.{Constant, NumericConstant} +import millfork.env.{CompoundConstant, Constant, MathOperator, NumericConstant} import millfork.node.ZRegister /** @@ -66,7 +66,7 @@ object AlwaysGoodZ80Optimizations { ) val ReloadingKnownValueFromMemory = new RuleBasedAssemblyOptimization("Reloading known value from memory", - needsFlowInfo = FlowInfoRequirement.NoRequirement, + needsFlowInfo = FlowInfoRequirement.ForwardFlow, for7Registers(register => Is8BitLoad(ZRegister.MEM_HL, register) ~ (Linear & Not(Changes(ZRegister.H)) & Not(Changes(ZRegister.L)) & Not(ChangesMemory) & Not(Changes(register)) & Not(IsRegular8BitLoadFrom(ZRegister.MEM_HL))).* ~ @@ -82,42 +82,74 @@ object AlwaysGoodZ80Optimizations { ), (Is8BitLoad(ZRegister.MEM_ABS_8, ZRegister.A) & MatchParameter(0)).captureLine(1) ~ (Linear & DoesntChangeMemoryAt(1) & Not(Changes(ZRegister.A))).* ~ - (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_ABS_8) & MatchParameter(0)) ~~> { code => code.init } + (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_ABS_8) & MatchParameter(0)) ~~> { code => code.init }, + + (Is8BitLoad(ZRegister.MEM_HL, ZRegister.A) & MatchConstantInHL(0)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Changes(ZRegister.A))).* ~ + (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_ABS_8) & MatchParameter(0)) ~~> { code => code.init }, + + (Is8BitLoad(ZRegister.MEM_ABS_8, ZRegister.A) & MatchParameter(0)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Changes(ZRegister.A))).* ~ + (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_HL) & MatchConstantInHL(0)) ~~> { code => code.init }, + + (Is8BitLoad(ZRegister.MEM_HL, ZRegister.A) & MatchConstantInHL(0)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Changes(ZRegister.A))).* ~ + (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_HL) & MatchConstantInHL(0)) ~~> { code => code.init }, + + (Is16BitLoad(ZRegister.MEM_ABS_16, ZRegister.HL) & MatchParameter(0)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Changes(ZRegister.HL))).* ~ + (Elidable & Is16BitLoad(ZRegister.HL, ZRegister.MEM_ABS_16) & MatchParameter(0)) ~~> { code => code.init }, + + (Is8BitLoad(ZRegister.MEM_ABS_8, ZRegister.A) & MatchParameter(0) & MatchRegister(ZRegister.A, 2)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Is8BitLoad(ZRegister.A, ZRegister.MEM_ABS_8))).* ~ + (Elidable & Is8BitLoad(ZRegister.A, ZRegister.MEM_ABS_8) & MatchParameter(0)) ~~> { (code, ctx) => + code.init :+ ZLine.ldImm8(ZRegister.A, ctx.get[Int](2)) + }, + + (Is16BitLoad(ZRegister.MEM_ABS_16, ZRegister.HL) & MatchParameter(0) & MatchConstantInHL(2)).captureLine(1) ~ + (Linear & DoesntChangeMemoryAt(1) & Not(Is16BitLoad(ZRegister.HL, ZRegister.MEM_ABS_16))).* ~ + (Elidable & Is16BitLoad(ZRegister.HL, ZRegister.MEM_ABS_16) & MatchParameter(0)) ~~> { (code, ctx) => + code.init :+ ZLine.ldImm16(ZRegister.HL, ctx.get[Constant](2)) + }, ) val PointlessLoad = new RuleBasedAssemblyOptimization("Pointless load", needsFlowInfo = FlowInfoRequirement.BackwardFlow, + // 0-6 for7Registers(register => (Elidable & Is8BitLoadTo(register) & DoesntMatterWhatItDoesWith(register)) ~~> (_ => Nil) ), + // 7-11 for5LargeRegisters(register => (Elidable & Is16BitLoadTo(register) & DoesntMatterWhatItDoesWith(register)) ~~> (_ => Nil) ), + // 12-18 for7Registers(register => (Is8BitLoad(register, ZRegister.IMM_8) & MatchImmediate(0)) ~ (Linear & Not(Changes(register))).* ~ (Elidable & Is8BitLoad(register, ZRegister.IMM_8) & MatchImmediate(0)) ~~> (_.init) ), + // 19-23 for5LargeRegisters(register => (Is16BitLoad(register, ZRegister.IMM_16) & MatchImmediate(0)) ~ (Linear & Not(Changes(register))).* ~ (Elidable & Is16BitLoad(register, ZRegister.IMM_16) & MatchImmediate(0)) ~~> (_.init) ), - + // 24 (Elidable & Is8BitLoadTo(ZRegister.MEM_HL)) ~ (Linear & Not(ConcernsMemory) & Not(Changes(ZRegister.HL))).* ~ Is8BitLoadTo(ZRegister.MEM_HL) ~~> (_.tail), - + // 25 (Elidable & Is8BitLoadTo(ZRegister.MEM_DE)) ~ (Linear & Not(ConcernsMemory) & Not(Changes(ZRegister.DE))).* ~ Is8BitLoadTo(ZRegister.MEM_DE) ~~> (_.tail), - + // 26 (Elidable & Is8BitLoadTo(ZRegister.MEM_BC)) ~ (Linear & Not(ConcernsMemory) & Not(Changes(ZRegister.BC))).* ~ Is8BitLoadTo(ZRegister.MEM_BC) ~~> (_.tail), - + // 27 (Elidable & MatchTargetIxOffsetOf8BitLoad(0) & MatchUnimportantIxOffset(0)) ~~> (_ => Nil), - + // 28-34 for7Registers(register => (Elidable & Is8BitLoadTo(register) & NoOffset & MatchSourceRegisterAndOffset(1)) ~ (Linear & Not(Concerns(register)) & DoesntChangeMatchedRegisterAndOffset(1)).* ~ @@ -132,10 +164,29 @@ object AlwaysGoodZ80Optimizations { }, head.parameter) } ), - + // 35-41 for7Registers(register => (Elidable & Is8BitLoad(register, register)) ~~> (_ => Nil) ), + // 42-48 + for7Registers(register => + (Elidable & Is8BitLoadTo(register) & MatchSourceRegisterAndOffset(0)) ~ + (Linear & Not(Concerns(register)) & DoesntChangeMatchedRegisterAndOffset(0)).* ~ + (Elidable & HasOpcodeIn(Set(ADD, ADC, XOR, OR, AND, CP, SUB, SBC)) & HasRegisters(OneRegister(register)) & DoesntMatterWhatItDoesWith(register)) ~~> ((code,ctx) => + code.tail.init :+ code.last.copy(registers = ctx.get[RegisterAndOffset](0).toOneRegister) + ) + ), + + ) + + val PointlessStackStashing = new RuleBasedAssemblyOptimization("Pointless stack stashing", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + // 0-4 + for5LargeRegisters(register => { + (Elidable & HasOpcode(PUSH) & HasRegisterParam(register)) ~ + (Linear & Not(HasOpcode(POP)) & Not(Changes(register))).* ~ + (Elidable & HasOpcode(POP) & HasRegisterParam(register)) ~~> (_.tail.init) + }), ) @@ -195,6 +246,22 @@ object AlwaysGoodZ80Optimizations { simplifiable16BitAddWithSplitTarget(ZRegister.IYH, ZRegister.IYL, ZRegister.IY, ZRegister.BC), simplifiable16BitAddWithSplitTarget(ZRegister.IYH, ZRegister.IYL, ZRegister.IY, ZRegister.DE), + (Elidable & HasOpcode(ADD_16) & HasRegisters(TwoRegisters(ZRegister.HL, ZRegister.BC)) & MatchRegister(ZRegister.BC, 0) & MatchRegister(ZRegister.HL, 1) & DoesntMatterWhatItDoesWithFlags) ~~> { (code, ctx) => + List(ZLine.ldImm16(ZRegister.HL, ctx.get[Int](0) + ctx.get[Int](1))) + }, + (Elidable & HasOpcode(ADD_16) & HasRegisters(TwoRegisters(ZRegister.HL, ZRegister.DE)) & MatchRegister(ZRegister.DE, 0) & MatchRegister(ZRegister.HL, 1) & DoesntMatterWhatItDoesWithFlags) ~~> { (code, ctx) => + List(ZLine.ldImm16(ZRegister.HL, ctx.get[Int](0) + ctx.get[Int](1))) + }, + + + (Elidable & HasOpcode(ADD_16) & HasRegisters(TwoRegisters(ZRegister.HL, ZRegister.BC)) & MatchRegister(ZRegister.BC, 0) & MatchConstantInHL(1) & DoesntMatterWhatItDoesWithFlags) ~~> { (code, ctx) => + List(ZLine.ldImm16(ZRegister.HL, (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify)) + }, + (Elidable & HasOpcode(ADD_16) & HasRegisters(TwoRegisters(ZRegister.HL, ZRegister.DE)) & MatchRegister(ZRegister.DE, 0) & MatchConstantInHL(1) & DoesntMatterWhatItDoesWithFlags) ~~> { (code, ctx) => + List(ZLine.ldImm16(ZRegister.HL, (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify)) + }, + + (Elidable & Is8BitLoad(ZRegister.D, ZRegister.H)) ~ (Elidable & Is8BitLoad(ZRegister.E, ZRegister.L)) ~ (Elidable & Is8BitLoadTo(ZRegister.L)) ~ @@ -216,7 +283,58 @@ object AlwaysGoodZ80Optimizations { code.last) }, - (Elidable & HasOpcodeIn(Set(ADD, OR, AND, XOR, SUB)) & Has8BitImmediate(0) & DoesntMatterWhatItDoesWithFlags) ~~> (_ => Nil), + (Elidable & HasOpcodeIn(Set(ADD, OR, XOR, SUB)) & Has8BitImmediate(0) & DoesntMatterWhatItDoesWithFlags) ~~> (_ => Nil), + (Elidable & HasOpcode(AND) & Has8BitImmediate(0xff) & DoesntMatterWhatItDoesWithFlags) ~~> (_ => Nil), + (Elidable & HasOpcode(AND) & Has8BitImmediate(0) & DoesntMatterWhatItDoesWithFlags) ~~> (_ => List(ZLine.ldImm8(ZRegister.A, 0))), + (Elidable & HasOpcode(AND) & Has8BitImmediate(0) & DoesntMatterWhatItDoesWithFlags) ~~> (_ => List(ZLine.ldImm8(ZRegister.A, 0))), + + + (Elidable & HasOpcode(OR) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0)) ~ + (Elidable & HasOpcodeIn(Set(JP, JR)) & HasRegisters(IfFlagSet(ZFlag.S)) & DoesntMatterWhatItDoesWithFlags) ~ + Where(ctx => ctx.get[Constant](1).isInstanceOf[NumericConstant]) ~~> { (code, ctx) => + val value = (ctx.get[Int](0) | ctx.get[NumericConstant](1).value).toInt & 0xff + if (value.&(0x80) == 0) List(ZLine.ldImm8(ZRegister.A, value)) + else List(ZLine.ldImm8(ZRegister.A, value), code.last.copy(registers = NoRegisters)) + }, + + (Elidable & HasOpcode(OR) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0)) ~ + (Elidable & HasOpcodeIn(Set(JP, JR)) & HasRegisters(IfFlagClear(ZFlag.S)) & DoesntMatterWhatItDoesWithFlags) ~ + Where(ctx => ctx.get[Constant](1).isInstanceOf[NumericConstant]) ~~> { (code, ctx) => + val value = (ctx.get[Int](0) | ctx.get[NumericConstant](1).value).toInt & 0xff + if (value.&(0x80) != 0) List(ZLine.ldImm8(ZRegister.A, value)) + else List(ZLine.ldImm8(ZRegister.A, value), code.last.copy(registers = NoRegisters)) + }, + + (Elidable & HasOpcode(ADD) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, (ctx.get[Constant](1) + ctx.get[Int](0)).quickSimplify)) + }, + + (Elidable & HasOpcode(SUB) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, (NumericConstant(ctx.get[Int](0) & 0xff, 1) - ctx.get[Constant](1)).quickSimplify)) + }, + + (Elidable & HasOpcode(OR) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.Or, NumericConstant(ctx.get[Int](0) & 0xff, 1), ctx.get[Constant](1)).quickSimplify)) + }, + + (Elidable & HasOpcode(XOR) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.Exor, NumericConstant(ctx.get[Int](0) & 0xff, 1), ctx.get[Constant](1)).quickSimplify)) + }, + + (Elidable & HasOpcode(AND) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.And, NumericConstant(ctx.get[Int](0) & 0xff, 1), ctx.get[Constant](1)).quickSimplify)) + }, + + (Elidable & HasOpcode(ADD) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0)) ~ + (Elidable & HasOpcode(DAA) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.DecimalPlus, NumericConstant(ctx.get[Int](0) & 0xff, 1), ctx.get[Constant](1)).quickSimplify)) + }, + + (Elidable & HasOpcode(SUB) & Match8BitImmediate(1) & MatchRegister(ZRegister.A, 0)) ~ + (Elidable & HasOpcode(DAA) & DoesntMatterWhatItDoesWithFlags) ~~> {(code, ctx) => + List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.DecimalMinus, NumericConstant(ctx.get[Int](0) & 0xff, 1), ctx.get[Constant](1)).quickSimplify)) + }, + ) val FreeHL = new RuleBasedAssemblyOptimization("Free HL", @@ -252,16 +370,32 @@ object AlwaysGoodZ80Optimizations { )), ) + val UnusedCodeRemoval = new RuleBasedAssemblyOptimization("Unreachable code removal", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcodeIn(Set(JP, JR)) & HasRegisters(NoRegisters)) ~ (Not(HasOpcode(LABEL)) & Elidable).+ ~~> (c => c.head :: Nil) + ) + val UnusedLabelRemoval = new RuleBasedAssemblyOptimization("Unused label removal", needsFlowInfo = FlowInfoRequirement.JustLabels, (Elidable & HasOpcode(LABEL) & HasCallerCount(0)) ~~> (_ => Nil) ) + val BranchInPlaceRemoval = new RuleBasedAssemblyOptimization("Branch in place", + needsFlowInfo = FlowInfoRequirement.NoRequirement, + (HasOpcodeIn(Set(JP, JR)) & MatchJumpTarget(0) & Elidable) ~ + HasOpcodeIn(ZOpcodeClasses.NoopDiscards).* ~ + (HasOpcode(LABEL) & MatchJumpTarget(0)) ~~> (c => c.last :: Nil), + ) + val All: List[AssemblyOptimization[ZLine]] = List[AssemblyOptimization[ZLine]]( + BranchInPlaceRemoval, + EmptyMemoryStoreRemoval, FreeHL, PointlessLoad, + PointlessStackStashing, ReloadingKnownValueFromMemory, SimplifiableMaths, + UnusedCodeRemoval, UnusedLabelRemoval, UsingKnownValueFromAnotherRegister, ) diff --git a/src/main/scala/millfork/assembly/z80/opt/CoarseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/z80/opt/CoarseFlowAnalyzer.scala index b2bb7783..201db64a 100644 --- a/src/main/scala/millfork/assembly/z80/opt/CoarseFlowAnalyzer.scala +++ b/src/main/scala/millfork/assembly/z80/opt/CoarseFlowAnalyzer.scala @@ -77,6 +77,10 @@ object CoarseFlowAnalyzer { case ZLine(CP, _, _, _) => currentStatus = currentStatus.copy(cf = AnyStatus, zf = AnyStatus, sf = AnyStatus, pf = AnyStatus, hf = AnyStatus) + case ZLine(LD_16, TwoRegisters(t, ZRegister.IMM_16), NumericConstant(value, _), _) => + currentStatus = currentStatus.setRegister(t, SingleStatus(value.toInt)) + case ZLine(LD_16, TwoRegisters(ZRegister.HL, ZRegister.IMM_16), xx, _) => + currentStatus = currentStatus.setHL(SingleStatus(xx)) case ZLine(LD, TwoRegisters(t, ZRegister.IMM_8), NumericConstant(value, _), _) => currentStatus = currentStatus.setRegister(t, SingleStatus(value.toInt)) case ZLine(LD, TwoRegistersOffset(t, ZRegister.IMM_8, o), NumericConstant(value, _), _) => @@ -88,6 +92,14 @@ object CoarseFlowAnalyzer { case ZLine(ADD_16, TwoRegisters(t, s), _, _) => currentStatus = currentStatus.copy(cf = AnyStatus, zf = AnyStatus, sf = AnyStatus, pf = AnyStatus, hf = AnyStatus) .setRegister(t, (currentStatus.getRegister(t) <*> currentStatus.getRegister(s)) ((m, n) => (m + n) & 0xffff)) + + case ZLine(SLA, OneRegister(r), _, _) => + currentStatus = currentStatus.copy(cf = AnyStatus, zf = AnyStatus, sf = AnyStatus, pf = AnyStatus, hf = AnyStatus) + .setRegister(r, currentStatus.getRegister(r).map(_.<<(1).&(0xff))) + case ZLine(SRL, OneRegister(r), _, _) => + currentStatus = currentStatus.copy(cf = AnyStatus, zf = AnyStatus, sf = AnyStatus, pf = AnyStatus, hf = AnyStatus) + .setRegister(r, currentStatus.getRegister(r).map(_.>>(1).&(0x7f))) + case ZLine(opcode, registers, _, _) => currentStatus = currentStatus.copy(cf = AnyStatus, zf = AnyStatus, sf = AnyStatus, pf = AnyStatus, hf = AnyStatus) if (ZOpcodeClasses.ChangesAAlways(opcode)) currentStatus = currentStatus.copy(a = AnyStatus) diff --git a/src/main/scala/millfork/assembly/z80/opt/CpuStatus.scala b/src/main/scala/millfork/assembly/z80/opt/CpuStatus.scala index 9683d73b..66c0ebc7 100644 --- a/src/main/scala/millfork/assembly/z80/opt/CpuStatus.scala +++ b/src/main/scala/millfork/assembly/z80/opt/CpuStatus.scala @@ -3,6 +3,7 @@ package millfork.assembly.z80.opt import millfork.assembly.opt._ import millfork.assembly.z80.ZFlag +import millfork.env.{Constant, NumericConstant} import millfork.node.ZRegister /** @@ -16,6 +17,7 @@ case class CpuStatus(a: Status[Int] = UnknownStatus, e: Status[Int] = UnknownStatus, h: Status[Int] = UnknownStatus, l: Status[Int] = UnknownStatus, + hl: Status[Constant] = UnknownStatus, ixh: Status[Int] = UnknownStatus, ixl: Status[Int] = UnknownStatus, iyh: Status[Int] = UnknownStatus, @@ -36,8 +38,8 @@ case class CpuStatus(a: Status[Int] = UnknownStatus, case ZRegister.C => this.copy(c = value) case ZRegister.D => this.copy(d = value) case ZRegister.E => this.copy(e = value) - case ZRegister.H => this.copy(h = value) - case ZRegister.L => this.copy(l = value) + case ZRegister.H => this.copy(h = value, hl = AnyStatus) + case ZRegister.L => this.copy(l = value, hl = AnyStatus) case ZRegister.IXH => this.copy(ixh = value) case ZRegister.IXL => this.copy(ixl = value) case ZRegister.IYH => this.copy(iyh = value) @@ -54,7 +56,7 @@ case class CpuStatus(a: Status[Int] = UnknownStatus, case ZRegister.SP => this case ZRegister.BC => this.copy(b = value.hi, c = value.lo) case ZRegister.DE => this.copy(d = value.hi, e = value.lo) - case ZRegister.HL => this.copy(h = value.hi, l = value.lo) + case ZRegister.HL => this.copy(h = value.hi, l = value.lo, hl = value.map(NumericConstant(_, 2))) case ZRegister.IX => this.copy(ixh = value.hi, ixl = value.lo) case ZRegister.IY => this.copy(iyh = value.hi, iyl = value.lo) case ZRegister.AF => this.copy(a = value.hi, cf = AnyStatus, zf = AnyStatus, hf = AnyStatus, pf = AnyStatus, sf = AnyStatus) @@ -123,6 +125,13 @@ case class CpuStatus(a: Status[Int] = UnknownStatus, hf = this.hf ~ that.hf, ) + def setHL(c: Status[Constant]): CpuStatus = c match { + case SingleStatus(NumericConstant(nn, _)) => this.copy(l = SingleStatus(nn.toInt.&(0xff)), h = SingleStatus(nn.toInt.&(0xff00).>>(8)), hl = c) + case SingleStatus(cc) => this.copy(l = AnyStatus, h = AnyStatus, hl = c) + case AnyStatus => this.copy(l = AnyStatus, h = AnyStatus, hl = AnyStatus) + case UnknownStatus => this.copy(l = UnknownStatus, h = UnknownStatus, hl = UnknownStatus) + } + override def toString: String = { val memRepr = if (memIx.isEmpty) "" else (0 to memIx.keys.max).map(i => memIx.getOrElse(i, UnknownStatus)).mkString("") diff --git a/src/main/scala/millfork/assembly/z80/opt/EmptyMemoryStoreRemoval.scala b/src/main/scala/millfork/assembly/z80/opt/EmptyMemoryStoreRemoval.scala new file mode 100644 index 00000000..5fa836dd --- /dev/null +++ b/src/main/scala/millfork/assembly/z80/opt/EmptyMemoryStoreRemoval.scala @@ -0,0 +1,101 @@ +package millfork.assembly.z80.opt + +import millfork.assembly.opt.SingleStatus +import millfork.assembly.z80.{OneRegister, TwoRegisters, ZLine} +import millfork.assembly.{AssemblyOptimization, OptimizationContext} +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node.ZRegister + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +object EmptyMemoryStoreRemoval extends AssemblyOptimization[ZLine] { + override def name = "Removing pointless stores to automatic variables" + + override def optimize(f: NormalFunction, code: List[ZLine], optimizationContext: OptimizationContext): List[ZLine] = { + val paramVariables = f.params match { +// case NormalParamSignature(List(MemoryVariable(_, typ, _))) if typ.size == 1 => +// Set[String]() + case NormalParamSignature(ps) => + ps.map(_.name).toSet + case _ => + // assembly functions do not get this optimization + return code + } + val flow = FlowAnalyzer.analyze(f, code, optimizationContext.options, FlowInfoRequirement.BothFlows) + import millfork.node.ZRegister._ + val stillUsedVariables = code.flatMap { + case ZLine(_, TwoRegisters(MEM_ABS_8 | MEM_ABS_16, _), MemoryAddressConstant(th), _) => Some(th.name) + case ZLine(_, TwoRegisters(_, MEM_ABS_8 | MEM_ABS_16), MemoryAddressConstant(th), _) => Some(th.name) + case ZLine(_, TwoRegisters(_, IMM_16), MemoryAddressConstant(th), _) => Some(th.name) + case ZLine(_, TwoRegisters(MEM_ABS_8 | MEM_ABS_16, _), CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(_, _)), _) => Some(th.name) + case ZLine(_, TwoRegisters(_, MEM_ABS_8 | MEM_ABS_16), CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(_, _)), _) => Some(th.name) + case ZLine(_, TwoRegisters(_, IMM_16), CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(_, _)), _) => Some(th.name) + case _ => None + }.toSet + val variablesWithAddressesTaken = code.zipWithIndex.flatMap { + case (ZLine(_, _, SubbyteConstant(MemoryAddressConstant(th), _), _), _) => + Some(th.name) + case (ZLine(_, _, SubbyteConstant(CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(_, _)), _), _), _) => + Some(th.name) + case (ZLine(_, + TwoRegisters(ZRegister.MEM_HL, _) | TwoRegisters(_, ZRegister.MEM_HL) | OneRegister(ZRegister.MEM_HL), + _, _), i) => + flow(i)._1.statusBefore.hl match { + case SingleStatus(MemoryAddressConstant(th)) => + if (flow(i)._1.importanceAfter.hlNumeric != Unimportant) Some(th.name) + else None + case SingleStatus(CompoundConstant(MathOperator.Plus, MemoryAddressConstant(th), NumericConstant(_, _))) => + if (flow(i)._1.importanceAfter.hlNumeric != Unimportant) Some(th.name) + else None + case _ => None // TODO: ??? + } + case _ => None + }.toSet + val allLocalVariables = f.environment.getAllLocalVariables + val localVariables = allLocalVariables.filter { + case MemoryVariable(name, typ, VariableAllocationMethod.Auto | VariableAllocationMethod.Zeropage) => + typ.size > 0 && !paramVariables(name) && stillUsedVariables(name) && !variablesWithAddressesTaken(name) + case _ => false + } + + if (localVariables.isEmpty) { + return code + } + + val toRemove = mutable.Set[Int]() + val badVariables = mutable.Set[String]() + + for(v <- localVariables) { + val lifetime = VariableLifetime.apply(v.name, flow) + val lastaccess = lifetime.last + if (lastaccess >= 0) { + val lastVariableAccess = code(lastaccess) + import millfork.assembly.z80.ZOpcode._ + if (lastVariableAccess match { + case ZLine(LD, TwoRegisters(MEM_HL, _), _, true) => true + case ZLine(LD | LD_16, TwoRegisters(MEM_ABS_8 | MEM_ABS_16, _), _, true) => true + case ZLine(INC | DEC, OneRegister(MEM_HL), _, true) => + val importances = flow(lastaccess)._1.importanceAfter + Seq(importances.sf, importances.zf).forall(_ == Unimportant) + case ZLine(SLA | SLL | SRA | SRL | RL | RR | RLC | RRC, OneRegister(MEM_HL), _, true) => + val importances = flow(lastaccess)._1.importanceAfter + Seq(importances.sf, importances.zf, importances.cf).forall(_ == Unimportant) + case _ => false + }) { + badVariables += v.name + toRemove += lastaccess + } + } + } + if (toRemove.isEmpty) { + code + } else { + ErrorReporting.debug(s"Removing pointless store(s) to ${badVariables.mkString(", ")}") + code.zipWithIndex.filter(x => !toRemove(x._2)).map(_._1) + } + } +} diff --git a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala index 5126d96e..591ecf78 100644 --- a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala +++ b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala @@ -36,6 +36,7 @@ case class CpuImportance(a: Importance = UnknownImportance, e: Importance = UnknownImportance, h: Importance = UnknownImportance, l: Importance = UnknownImportance, + hlNumeric: Importance = UnknownImportance, ixh: Importance = UnknownImportance, ixl: Importance = UnknownImportance, iyh: Importance = UnknownImportance, @@ -50,17 +51,18 @@ case class CpuImportance(a: Importance = UnknownImportance, ) { override def toString: String = { val memRepr = if (memIx.isEmpty) "" else (0 to memIx.keys.max).map(i => memIx.getOrElse(i, UnknownImportance)).mkString("") - s"A=$a,B=$b,C=$c,D=$d,E=$e,H=$h,L=$l,IX=$ixh$ixl,Y=$iyh$iyl; Z=$zf,C=$cf,N=$nf,S=$sf,P=$pf,H=$hf; M=" ++ memRepr.padTo(4, ' ') + s"A=$a,B=$b,C=$c,D=$d,E=$e,H=$h,L=$l,IX=$ixh$ixl,Y=$iyh$iyl; Z=$zf,C=$cf,N=$nf,S=$sf,P=$pf,H=$hf; HL=$hlNumeric; M=" ++ memRepr.padTo(4, ' ') } def ~(that: CpuImportance) = new CpuImportance( a = this.a ~ that.a, - b = this.a ~ that.a, - c = this.a ~ that.a, - d = this.a ~ that.a, - e = this.a ~ that.a, - h = this.a ~ that.a, - l = this.a ~ that.a, + b = this.b ~ that.b, + c = this.c ~ that.c, + d = this.d ~ that.d, + e = this.e ~ that.e, + h = this.h ~ that.h, + l = this.l ~ that.l, + hlNumeric = this.hlNumeric ~ that.hlNumeric, ixh = this.ixh ~ that.ixh, ixl = this.ixl ~ that.ixl, iyh = this.iyh ~ that.iyh, @@ -111,9 +113,10 @@ case class CpuImportance(a: Importance = UnknownImportance, case ZRegister.D => this.copy(d = Important) case ZRegister.E => this.copy(e = Important) case ZRegister.DE | ZRegister.MEM_DE => this.copy(d = Important, e = Important) - case ZRegister.H => this.copy(h = Important) - case ZRegister.L => this.copy(l = Important) - case ZRegister.HL | ZRegister.MEM_HL => this.copy(h = Important, l = Important) + case ZRegister.H => this.copy(h = Important, hlNumeric = Important) + case ZRegister.L => this.copy(l = Important, hlNumeric = Important) + case ZRegister.HL => this.copy(h = Important, l = Important, hlNumeric = Important) + case ZRegister.MEM_HL => this.copy(h = Important, l = Important) case ZRegister.IXH => this.copy(ixh = Important) case ZRegister.IXL => this.copy(ixl = Important) case ZRegister.IYH => this.copy(iyh = Important) @@ -137,7 +140,7 @@ case class CpuImportance(a: Importance = UnknownImportance, case ZRegister.MEM_DE => this.copy(d = Important, e = Important) case ZRegister.H => this.copy(h = Unimportant) case ZRegister.L => this.copy(l = Unimportant) - case ZRegister.HL => this.copy(h = Unimportant, l = Unimportant) + case ZRegister.HL => this.copy(h = Unimportant, l = Unimportant, hlNumeric = Unimportant) case ZRegister.MEM_HL => this.copy(h = Important, l = Important) case ZRegister.IXH => this.copy(ixh = Unimportant) case ZRegister.IXL => this.copy(ixl = Unimportant) @@ -193,20 +196,16 @@ object ReverseFlowAnalyzer { } val currentLine = codeArray(i) currentLine match { + case ZLine(LABEL, _, _, _) => () case ZLine(DJNZ, _, MemoryAddressConstant(Label(l)), _) => - val L = l - val labelIndex = codeArray.indexWhere { - case ZLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true - case _ => false - } + val labelIndex = getLabelIndex(codeArray, l) currentImportance = if (labelIndex < 0) finalImportance else (importanceArray(labelIndex) ~ currentImportance).butReadsRegister(ZRegister.B).butReadsFlag(ZFlag.Z) - case ZLine(JP | JR, IfFlagSet(_) | IfFlagClear(_), MemoryAddressConstant(Label(l)), _) => - val L = l - val labelIndex = codeArray.indexWhere { - case ZLine(LABEL, _, MemoryAddressConstant(Label(L)), _) => true - case _ => false - } - currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance + case ZLine(JP | JR, IfFlagSet(flag), MemoryAddressConstant(Label(l)), _) => + val labelIndex = getLabelIndex(codeArray, l) + currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance.butReadsFlag(flag) + case ZLine(JP | JR, IfFlagClear(flag), MemoryAddressConstant(Label(l)), _) => + val labelIndex = getLabelIndex(codeArray, l) + currentImportance = if (labelIndex < 0) finalImportance else importanceArray(labelIndex) ~ currentImportance.butReadsFlag(flag) case ZLine(DISCARD_HL, _, _, _) => currentImportance = currentImportance.copy(h = Unimportant, l = Unimportant) case ZLine(DISCARD_BCDEIX, _, _, _) => @@ -227,15 +226,65 @@ object ReverseFlowAnalyzer { case ZLine(OR | AND, OneRegister(ZRegister.A), _, _) => currentImportance = currentImportance.butReadsRegister(ZRegister.A) - case ZLine(AND | ADD | SUB | OR | XOR | CP, OneRegister(s), _, _) => - currentImportance = currentImportance.butReadsRegister(ZRegister.A).butReadsRegister(s) - case ZLine(ADC | SBC, OneRegister(s), _, _) => - currentImportance = currentImportance.butReadsRegister(ZRegister.A).butReadsRegister(s).butReadsFlag(ZFlag.C) + case ZLine(ADD | SUB | CP, OneRegister(s), _, _) => + currentImportance = currentImportance.butReadsRegister(s).copy( + a = Important, + cf = Unimportant, + zf = Unimportant, + sf = Unimportant, + hf = Unimportant, + pf = Unimportant, + nf = Unimportant + ) + case ZLine(ADD | SUB | CP, OneRegisterOffset(s, o), _, _) => + currentImportance = currentImportance.butReadsRegister(s, o).copy( + a = Important, + cf = Unimportant, + zf = Unimportant, + sf = Unimportant, + hf = Unimportant, + pf = Unimportant, + nf = Unimportant + ) - case ZLine(AND | ADD | SUB | OR | XOR | CP, OneRegisterOffset(s, o), _, _) => - currentImportance = currentImportance.butReadsRegister(ZRegister.A).butReadsRegister(s, o) + case ZLine(AND | OR | XOR, OneRegister(s), _, _) => + currentImportance = currentImportance.butReadsRegister(s).copy( + a = Important, + cf = Unimportant, + zf = Unimportant, + pf = Unimportant, + sf = Unimportant + ) + + case ZLine(AND | OR | XOR, OneRegisterOffset(s, o), _, _) => + currentImportance = currentImportance.butReadsRegister(s, o).copy( + a = Important, + cf = Unimportant, + zf = Unimportant, + pf = Unimportant, + sf = Unimportant + ) + case ZLine(ADC | SBC, OneRegister(s), _, _) => + currentImportance = currentImportance.butReadsRegister(s).copy( + a = Important, + cf = Important, + zf = Unimportant, + sf = Unimportant, + hf = Unimportant, + pf = Unimportant, + nf = Unimportant + ) case ZLine(ADC | SBC, OneRegisterOffset(s, o), _, _) => - currentImportance = currentImportance.butReadsRegister(ZRegister.A).butReadsRegister(s, o).butReadsFlag(ZFlag.C) + currentImportance = currentImportance.butReadsRegister(s, o).copy( + a = Important, + cf = Important, + zf = Unimportant, + sf = Unimportant, + hf = Unimportant, + pf = Unimportant, + nf = Unimportant + ) + case ZLine(INC | DEC | INC_16 | DEC_16, OneRegister(s), _, _) => currentImportance = currentImportance.butReadsRegister(s) @@ -247,6 +296,11 @@ object ReverseFlowAnalyzer { currentImportance = currentImportance.butReadsRegister(r) case ZLine(CALL, NoRegisters, _, _) => currentImportance = finalImportance.copy(memIx = currentImportance.memIx) + + case ZLine(SLA | SRL, OneRegister(r), _, _) => + currentImportance = currentImportance.butReadsRegister(r).butWritesFlag(ZFlag.C).butWritesFlag(ZFlag.Z) + case ZLine(RL | RR | RLC | RRC, OneRegister(r), _, _) => + currentImportance = currentImportance.butReadsRegister(r).butReadsFlag(ZFlag.C).butWritesFlag(ZFlag.Z) case _ => currentImportance = finalImportance // TODO } @@ -259,4 +313,11 @@ object ReverseFlowAnalyzer { importanceArray.toList } + + private def getLabelIndex(codeArray: Array[ZLine], L: String) = { + codeArray.indexWhere { + case ZLine(ZOpcode.LABEL, _, MemoryAddressConstant(Label(L)), _) => true + case _ => false + } + } } diff --git a/src/main/scala/millfork/assembly/z80/opt/RuleBasedAssemblyOptimization.scala b/src/main/scala/millfork/assembly/z80/opt/RuleBasedAssemblyOptimization.scala index 3f768200..0f2caac4 100644 --- a/src/main/scala/millfork/assembly/z80/opt/RuleBasedAssemblyOptimization.scala +++ b/src/main/scala/millfork/assembly/z80/opt/RuleBasedAssemblyOptimization.scala @@ -114,7 +114,7 @@ class AssemblyMatchingContext(val compilationOptions: CompilationOptions) { if (i eq null) { ErrorReporting.fatal(s"Value at index $i is null") } else { - ErrorReporting.fatal(s"Value at index $i is a ${t.getClass.getSimpleName}, not a ${clazz.getSimpleName}") + throw new IllegalStateException(s"Value at index $i is a ${t.getClass.getSimpleName}, not a ${clazz.getSimpleName}") } } } @@ -469,7 +469,14 @@ case class MatchImmediate(i: Int) extends AssemblyLinePattern { } } -case class RegisterAndOffset(register: ZRegister.Value, offset: Int) +case class RegisterAndOffset(register: ZRegister.Value, offset: Int) { + def toOneRegister: ZRegisters = register match { + case ZRegister.MEM_IX_D | ZRegister.MEM_IY_D => OneRegisterOffset(register, offset) + case _ => + if (offset != 0) ??? + OneRegister(register) + } +} case class MatchSourceRegisterAndOffset(i: Int) extends AssemblyLinePattern { override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: ZLine): Boolean = @@ -495,7 +502,10 @@ case class DoesntChangeMatchedRegisterAndOffset(i: Int) extends AssemblyLinePatt import ZRegister._ ro.register match { case AF | SP => false // ? - case MEM_ABS_8 | MEM_ABS_16 | MEM_HL | MEM_DE | MEM_BC => !line.changesMemory + case MEM_ABS_8 | MEM_ABS_16 => !line.changesMemory + case MEM_HL => !line.changesMemory && !line.changesRegister(ZRegister.HL) + case MEM_BC => !line.changesMemory && !line.changesRegister(ZRegister.BC) + case MEM_DE => !line.changesMemory && !line.changesRegister(ZRegister.DE) case _ => !line.changesRegisterAndOffset(ro.register, ro.offset) } } @@ -511,6 +521,24 @@ case class MatchParameter(i: Int) extends AssemblyLinePattern { } } +case class MatchJumpTarget(i: Int) extends AssemblyLinePattern { + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: ZLine): Boolean = + line.registers match { + case NoRegisters | IfFlagClear(_) | IfFlagSet(_) => ctx.addObject(i, line.parameter.quickSimplify) + case _ => false + } +} + +case class MatchConstantInHL(i: Int) extends AssemblyLinePattern { + override def validate(needsFlowInfo: FlowInfoRequirement.Value): Unit = + FlowInfoRequirement.assertForward(needsFlowInfo) + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: ZLine): Boolean = + flowInfo.statusBefore.hl match { + case SingleStatus(value) => ctx.addObject(i, value) + case _ => false + } +} + case class MatchOpcode(i: Int) extends AssemblyLinePattern { override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: ZLine): Boolean = ctx.addObject(i, line.opcode) @@ -829,6 +857,16 @@ case class Has8BitImmediate(i: Int) extends TrivialAssemblyLinePattern { override def toString: String = "#" + i } + +case class Match8BitImmediate(i: Int) extends AssemblyLinePattern { + + override def matchLineTo(ctx: AssemblyMatchingContext, flowInfo: FlowInfo, line: ZLine): Boolean = line.registers match { + case TwoRegisters(_, ZRegister.IMM_8) => ctx.addObject(i, line.parameter) + case OneRegister(ZRegister.IMM_8) => ctx.addObject(i, line.parameter) + case _ => false + } +} + case class HasImmediateWhere(predicate: Int => Boolean) extends TrivialAssemblyLinePattern { override def apply(line: ZLine): Boolean = (line.registers match { diff --git a/src/main/scala/millfork/assembly/z80/opt/VariableLifetime.scala b/src/main/scala/millfork/assembly/z80/opt/VariableLifetime.scala new file mode 100644 index 00000000..79582a9d --- /dev/null +++ b/src/main/scala/millfork/assembly/z80/opt/VariableLifetime.scala @@ -0,0 +1,65 @@ +package millfork.assembly.z80.opt + +import millfork.assembly.opt.SingleStatus +import millfork.assembly.z80.{OneRegister, TwoRegisters, ZLine} +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node.ZRegister + +/** + * @author Karol Stasiak + */ +object VariableLifetime { + + // This only works for non-stack variables. + // TODO: this is also probably very wrong + def apply(variableName: String, codeWithFlow: List[(FlowInfo, ZLine)]): Range = { + val flags = codeWithFlow.map { + case (_, ZLine(_, _, MemoryAddressConstant(MemoryVariable(n, _, _)), _)) => n == variableName + case (_, ZLine(_, _, CompoundConstant(MathOperator.Plus, MemoryAddressConstant(MemoryVariable(n, _, _)), NumericConstant(_, 1)), _)) => n == variableName + case (i, ZLine(_, TwoRegisters(ZRegister.MEM_HL, _) | TwoRegisters(_, ZRegister.MEM_HL) | OneRegister(ZRegister.MEM_HL), _, _)) => + i.statusBefore.hl match { + case SingleStatus(MemoryAddressConstant(MemoryVariable(n, _, _))) => n == variableName + case SingleStatus(CompoundConstant(MathOperator.Plus, MemoryAddressConstant(MemoryVariable(n, _, _)), NumericConstant(_, 1))) => n == variableName + case _ => false + } + case _ => false + } + if (flags.forall(!_)) return Range(0, 0) + var min = flags.indexOf(true) + var max = flags.lastIndexOf(true) + 1 + var changed = true + val labelMap = codeWithFlow.zipWithIndex.flatMap(a => a._1._2.parameter match { + case MemoryAddressConstant(Label(l)) => List(l -> a._2) + case _ => Nil + }).groupBy(_._1).mapValues(_.map(_._2).toSet) + + while (changed) { + changed = false + for ((label, indices) <- labelMap) { + if (indices.exists(i => i >= min && i < max)) { + indices.foreach { i => + val before = max - min + min = min min i + max = max max (i + 1) + if (max - min != before) { + changed = true + } + } + } + } + } + +// ErrorReporting.trace("Lifetime for " + variableName) +// codeWithFlow.zipWithIndex.foreach { +// case ((_, line), index) => +// if (index >= min && index < max) { +// ErrorReporting.trace(f"$line%-30s <") +// } else { +// ErrorReporting.trace(line.toString) +// } +// } + + Range(min, max) + } +} diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala index f31635bd..24d91bd9 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -1,7 +1,7 @@ package millfork.compiler.z80 import millfork.assembly.z80._ -import millfork.compiler._ +import millfork.compiler.{ComparisonType, _} import millfork.env.NumericConstant import millfork.node.{Expression, ZRegister} @@ -13,23 +13,7 @@ object Z80Comparisons { import ComparisonType._ def compile8BitComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { - (ctx.env.eval(l), ctx.env.eval(r)) match { - case (Some(NumericConstant(lc, _)), Some(NumericConstant(rc, _))) => - val constantCondition = compType match { - case Equal => lc == rc - case NotEqual => lc != rc - case GreaterSigned | GreaterUnsigned => lc > rc - case LessOrEqualSigned | LessOrEqualUnsigned => lc <= rc - case GreaterOrEqualSigned | GreaterOrEqualUnsigned=> lc >= rc - case LessSigned | LessUnsigned => lc < rc - } - return branches match { - case BranchIfFalse(b) => if (!constantCondition) List(ZLine.jump(b)) else Nil - case BranchIfTrue(b) => if (constantCondition) List(ZLine.jump(b)) else Nil - case _ => Nil - } - case _ => - } + handleConstantComparison(ctx, compType, l, r, branches).foreach(return _) compType match { case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => return compile8BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) @@ -54,4 +38,182 @@ object Z80Comparisons { } calculateFlags :+ jump } + + private def handleConstantComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): Option[List[ZLine]] = { + (ctx.env.eval(l), ctx.env.eval(r)) match { + case (Some(NumericConstant(lc, _)), Some(NumericConstant(rc, _))) => + val constantCondition = compType match { + case Equal => lc == rc + case NotEqual => lc != rc + case GreaterSigned | GreaterUnsigned => lc > rc + case LessOrEqualSigned | LessOrEqualUnsigned => lc <= rc + case GreaterOrEqualSigned | GreaterOrEqualUnsigned => lc >= rc + case LessSigned | LessUnsigned => lc < rc + } + return Some(branches match { + case BranchIfFalse(b) => if (!constantCondition) List(ZLine.jump(b)) else Nil + case BranchIfTrue(b) => if (constantCondition) List(ZLine.jump(b)) else Nil + case _ => Nil + }) + case _ => + } + None + } + + def compile16BitComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { + handleConstantComparison(ctx, compType, l, r, branches).foreach(return _) + compType match { + case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => + return compile16BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) + case _ => () + } + val calculateLeft = Z80ExpressionCompiler.compileToHL(ctx, l) + val calculateRight = Z80ExpressionCompiler.compileToHL(ctx, r) + val (calculated, useBC) = if (calculateLeft.exists(Z80ExpressionCompiler.changesBC)) { + if (calculateLeft.exists(Z80ExpressionCompiler.changesDE)) { + calculateRight ++ List(ZLine.register(ZOpcode.PUSH, ZRegister.HL)) ++ calculateLeft ++ List(ZLine.register(ZOpcode.POP, ZRegister.BC)) -> false + } else { + calculateRight ++ List(ZLine.ld8(ZRegister.D, ZRegister.H), ZLine.ld8(ZRegister.E, ZRegister.L)) ++ calculateLeft -> false + } + } else { + calculateRight ++ List(ZLine.ld8(ZRegister.B, ZRegister.H), ZLine.ld8(ZRegister.C, ZRegister.L)) ++ calculateLeft -> true + } + val calculateFlags = calculated ++ List( + ZLine.register(ZOpcode.OR, ZRegister.A), + ZLine.registers(ZOpcode.SBC_16, ZRegister.HL, if (useBC) ZRegister.BC else ZRegister.DE)) + if (branches == NoBranching) return calculateFlags + val jump = (compType, branches) match { + case (Equal, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (Equal, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (LessUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case (LessUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case _ => ??? + } + calculateFlags :+ jump + } + + def compileLongRelativeComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { + handleConstantComparison(ctx, compType, l, r, branches).foreach(return _) + compType match { + case Equal | NotEqual => throw new IllegalArgumentException + case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => + return compileLongRelativeComparison(ctx, ComparisonType.flip(compType), r, l, branches) + case _ => () + } + val lt = Z80ExpressionCompiler.getExpressionType(ctx, l) + val rt = Z80ExpressionCompiler.getExpressionType(ctx, r) + val size = lt.size max rt.size + val calculateLeft = Z80ExpressionCompiler.compileByteReads(ctx, l, size, ZExpressionTarget.HL) + val calculateRight = Z80ExpressionCompiler.compileByteReads(ctx, r, size, ZExpressionTarget.BC) + val preserveHl = isBytesFromHL(calculateLeft) + val preserveBc = isBytesFromBC(calculateRight) + val calculateFlags = calculateLeft.zip(calculateRight).zipWithIndex.flatMap { case ((lb, rb), i) => + import ZOpcode._ + import ZRegister._ + val sub = if (i == 0) SUB else SBC + var compareBytes = (lb, rb) match { + case (List(ZLine(LD, TwoRegisters(A, _), _, _)), + List(ZLine(LD, TwoRegisters(A, IMM_8), param, _))) => + lb :+ ZLine.imm8(sub, param) + case (List(ZLine(LD, TwoRegisters(A, _), _, _)), + List(ZLine(LD, TwoRegisters(A, reg), _, _))) if reg != MEM_ABS_8 => + lb :+ ZLine.register(sub, reg) + case (List(ZLine(LD, TwoRegisters(A, _), _, _)), _) => + Z80ExpressionCompiler.stashAFIfChangedF(rb :+ ZLine.ld8(E, A)) ++ lb :+ ZLine.register(sub, E) + case _ => + if (preserveBc || preserveHl) ??? // TODO: preserve HL/BC for the next round of comparisons + var compileArgs = rb ++ List(ZLine.ld8(E, A)) ++ Z80ExpressionCompiler.stashDEIfChanged(lb) ++ List(ZLine.ld8(D, A)) + if (i > 0) compileArgs = Z80ExpressionCompiler.stashAFIfChangedF(compileArgs) + compileArgs ++ List(ZLine.ld8(A, D), ZLine.register(sub, E)) + } + if (i > 0 && preserveBc) compareBytes = Z80ExpressionCompiler.stashBCIfChanged(compareBytes) + if (i > 0 && preserveHl) compareBytes = Z80ExpressionCompiler.stashHLIfChanged(compareBytes) + compareBytes + } + if (branches == NoBranching) return calculateFlags + val jump = (compType, branches) match { + case (Equal, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (Equal, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (LessUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case (LessUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case _ => ??? + } + calculateFlags :+ jump + } + + def compileLongEqualityComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { + handleConstantComparison(ctx, compType, l, r, branches).foreach(return _) + val lt = Z80ExpressionCompiler.getExpressionType(ctx, l) + val rt = Z80ExpressionCompiler.getExpressionType(ctx, r) + val size = lt.size max rt.size + val calculateLeft = Z80ExpressionCompiler.compileByteReads(ctx, l, size, ZExpressionTarget.HL) + val calculateRight = Z80ExpressionCompiler.compileByteReads(ctx, r, size, ZExpressionTarget.BC) + val preserveHl = isBytesFromHL(calculateLeft) + val preserveBc = isBytesFromBC(calculateRight) + val innerLabel = Z80Compiler.nextLabel("cp") + val (jump, epilogue) = (compType, branches) match { + case (Equal, BranchIfTrue(label)) => + ZLine.jump(innerLabel, IfFlagClear(ZFlag.Z)) -> ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (NotEqual, BranchIfFalse(label)) => + ZLine.jump(innerLabel, IfFlagClear(ZFlag.Z)) -> ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (Equal, BranchIfFalse(label)) => + ZLine.jump(innerLabel, IfFlagSet(ZFlag.Z)) -> ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfTrue(label)) => + ZLine.jump(innerLabel, IfFlagSet(ZFlag.Z)) -> ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (_, NoBranching) => ZLine.implied(ZOpcode.NOP) -> ZLine.implied(ZOpcode.NOP) + case _ => throw new IllegalArgumentException + } + val calculateFlags = calculateLeft.zip(calculateRight).zipWithIndex.flatMap { case ((lb, rb), i) => + var compareBytes = { + import ZOpcode._ + import ZRegister._ + (lb, rb) match { + case (_, List(ZLine(LD, TwoRegisters(A, IMM_8), param, _))) => + lb :+ ZLine.imm8(CP, param) + case (List(ZLine(LD, TwoRegisters(A, IMM_8), param, _)), _) => + rb :+ ZLine.imm8(CP, param) + case (List(ZLine(LD, TwoRegisters(A, _), _, _)), + List(ZLine(LD, TwoRegisters(A, reg), _, _))) if reg != MEM_ABS_8 => + lb :+ ZLine.register(CP, reg) + case (List(ZLine(LD, TwoRegisters(A, reg), _, _)), + List(ZLine(LD, TwoRegisters(A, _), _, _))) if reg != MEM_ABS_8 => + rb :+ ZLine.register(CP, reg) + case (List(ZLine(LD, TwoRegisters(A, _), _, _)), _) => + (rb :+ ZLine.ld8(E, A)) ++ lb :+ ZLine.register(CP, E) + case _ => + var actualLb = lb + if (i == 0 && preserveBc) actualLb = Z80ExpressionCompiler.stashBCIfChanged(actualLb) + actualLb ++ List(ZLine.ld8(E, A)) ++ Z80ExpressionCompiler.stashDEIfChanged(rb) :+ ZLine.register(CP, E) + } + } + if (i > 0 && preserveBc) compareBytes = Z80ExpressionCompiler.stashBCIfChanged(compareBytes) + if (i > 0 && preserveHl) compareBytes = Z80ExpressionCompiler.stashHLIfChanged(compareBytes) + if (i != size - 1 && branches != NoBranching) compareBytes :+ jump else compareBytes + } + if (branches == NoBranching) calculateFlags + else calculateFlags ++ List(epilogue, ZLine.label(innerLabel)) + } + + + private def isBytesFromHL(calculateLeft: List[List[ZLine]]) = { + calculateLeft(1) match { + case List(ZLine(ZOpcode.LD, TwoRegisters(ZRegister.A, ZRegister.H), _, _)) => true + case _ => false + } + } + + private def isBytesFromBC(calculateLeft: List[List[ZLine]]) = { + calculateLeft(1) match { + case List(ZLine(ZOpcode.LD, TwoRegisters(ZRegister.A, ZRegister.B), _, _)) => true + case _ => false + } + } } diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index 393adb0b..26805202 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -12,7 +12,7 @@ import millfork.error.ErrorReporting * @author Karol Stasiak */ object ZExpressionTarget extends Enumeration { - val A, HL, NOTHING = Value + val A, HL, BC, DE, NOTHING = Value } object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { @@ -21,6 +21,10 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { def compileToHL(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.HL) + def compileToBC(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.BC) + + def compileToDE(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.DE) + def changesBC(line: ZLine): Boolean = { import ZRegister._ if (ZOpcodeClasses.ChangesBCAlways(line.opcode)) return true @@ -35,6 +39,25 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { false } + def changesAF(line: ZLine): Boolean = { + import ZRegister._ + if (ZOpcodeClasses.ChangesAFAlways(line.opcode)) return true + if (ZOpcodeClasses.ChangesFirstRegister(line.opcode)) return line.registers match { + case TwoRegisters(A, _) => true + case _ => false + } + if (ZOpcodeClasses.ChangesOnlyRegister(line.opcode)) return line.registers match { + case OneRegister(A) => true + case _ => false + } + false + } + + def changesF(line: ZLine): Boolean = { + // TODO + ZOpcodeClasses.ChangesAFAlways(line.opcode) + } + def changesDE(line: ZLine): Boolean = { import ZRegister._ if (ZOpcodeClasses.ChangesDEAlways(line.opcode)) return true @@ -78,6 +101,12 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } + def stashAFIfChanged(lines: List[ZLine]): List[ZLine] = if (lines.exists(changesAF)) + ZLine.register(PUSH, ZRegister.AF) :: (lines :+ ZLine.register(POP, ZRegister.AF)) else lines + + def stashAFIfChangedF(lines: List[ZLine]): List[ZLine] = if (lines.exists(changesF)) + ZLine.register(PUSH, ZRegister.AF) :: (lines :+ ZLine.register(POP, ZRegister.AF)) else lines + def stashBCIfChanged(lines: List[ZLine]): List[ZLine] = if (lines.exists(changesBC)) ZLine.register(PUSH, ZRegister.BC) :: (lines :+ ZLine.register(POP, ZRegister.BC)) else lines @@ -87,27 +116,36 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { def stashHLIfChanged(lines: List[ZLine]): List[ZLine] = if (lines.exists(changesHL)) ZLine.register(PUSH, ZRegister.HL) :: (lines :+ ZLine.register(POP, ZRegister.HL)) else lines - def targetifyA(target: ZExpressionTarget.Value, lines: List[ZLine], isSigned: Boolean): List[ZLine] = target match { - case ZExpressionTarget.NOTHING | ZExpressionTarget.A => lines - case ZExpressionTarget.HL => lines ++ (if (isSigned) { - val label = Z80Compiler.nextLabel("sx") - List( - ZLine.ld8(ZRegister.L, ZRegister.A), - ZLine.ldImm8(ZRegister.H, 0xff), - ZLine.imm8(OR, 0x7f), - ZLine.jump(label, IfFlagSet(ZFlag.S)), // TODO: gameboy has no S flag - ZLine.ldImm8(ZRegister.H, 0), - ZLine.label(label)) - } else { - List( - ZLine.ld8(ZRegister.L, ZRegister.A), - ZLine.ldImm8(ZRegister.H, 0)) - }) + def targetifyA(target: ZExpressionTarget.Value, lines: List[ZLine], isSigned: Boolean): List[ZLine] = { + def toWord(h:ZRegister.Value, l: ZRegister.Value) ={ + lines ++ (if (isSigned) { + val label = Z80Compiler.nextLabel("sx") + List( + ZLine.ld8(l, ZRegister.A), + ZLine.ldImm8(h, 0xff), + ZLine.imm8(OR, 0x7f), + ZLine.jump(label, IfFlagSet(ZFlag.S)), // TODO: gameboy has no S flag + ZLine.ldImm8(h, 0), + ZLine.label(label)) + } else { + List( + ZLine.ld8(l, ZRegister.A), + ZLine.ldImm8(h, 0)) + }) + } + target match { + case ZExpressionTarget.NOTHING | ZExpressionTarget.A => lines + case ZExpressionTarget.HL => toWord(ZRegister.H, ZRegister.L) + case ZExpressionTarget.BC => toWord(ZRegister.B, ZRegister.C) + case ZExpressionTarget.DE => toWord(ZRegister.D, ZRegister.E) + } } def targetifyHL(target: ZExpressionTarget.Value, lines: List[ZLine]): List[ZLine] = target match { case ZExpressionTarget.NOTHING | ZExpressionTarget.HL => lines case ZExpressionTarget.A => lines :+ ZLine.ld8(ZRegister.A, ZRegister.L) + case ZExpressionTarget.BC => lines ++ List(ZLine.ld8(ZRegister.C, ZRegister.L), ZLine.ld8(ZRegister.B, ZRegister.H)) + case ZExpressionTarget.DE => lines ++ List(ZLine.ld8(ZRegister.E, ZRegister.L), ZLine.ld8(ZRegister.D, ZRegister.H)) } def compile(ctx: CompilationContext, expression: Expression, target: ZExpressionTarget.Value, branches: BranchSpec = BranchSpec.None): List[ZLine] = { @@ -121,6 +159,10 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { List(ZLine.ldImm8(ZRegister.A, const)) case ZExpressionTarget.HL => List(ZLine.ldImm16(ZRegister.HL, const)) + case ZExpressionTarget.BC => + List(ZLine.ldImm16(ZRegister.BC, const)) + case ZExpressionTarget.DE => + List(ZLine.ldImm16(ZRegister.DE, const)) case ZExpressionTarget.NOTHING => Nil // TODO } @@ -135,8 +177,9 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case 1 => loadByte(v.toAddress, target) case 2 => target match { case ZExpressionTarget.NOTHING => Nil - case ZExpressionTarget.HL => - List(ZLine.ldAbs16(ZRegister.HL, v)) + case ZExpressionTarget.HL => List(ZLine.ldAbs16(ZRegister.HL, v)) + case ZExpressionTarget.BC => List(ZLine.ldAbs16(ZRegister.BC, v)) + case ZExpressionTarget.DE => List(ZLine.ldAbs16(ZRegister.DE, v)) } case _ => ??? } @@ -148,6 +191,10 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.HL => List(ZLine.ldViaIx(ZRegister.L, v.baseOffset), ZLine.ldViaIx(ZRegister.H, v.baseOffset + 1)) + case ZExpressionTarget.BC => + List(ZLine.ldViaIx(ZRegister.C, v.baseOffset), ZLine.ldViaIx(ZRegister.B, v.baseOffset + 1)) + case ZExpressionTarget.DE => + List(ZLine.ldViaIx(ZRegister.E, v.baseOffset), ZLine.ldViaIx(ZRegister.D, v.baseOffset + 1)) } case _ => ??? } @@ -176,6 +223,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A=> List(ZLine.ld8(ZRegister.A, ZRegister.H)) case ZExpressionTarget.HL=> List(ZLine.ld8(ZRegister.L, ZRegister.H), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.BC=> List(ZLine.ld8(ZRegister.C, ZRegister.H), ZLine.ldImm8(ZRegister.B, 0)) + case ZExpressionTarget.DE=> List(ZLine.ld8(ZRegister.E, ZRegister.H), ZLine.ldImm8(ZRegister.D, 0)) }) } case "lo" => @@ -187,6 +236,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A => List(ZLine.ld8(ZRegister.A, ZRegister.L)) case ZExpressionTarget.HL => List(ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.BC => List(ZLine.ld8(ZRegister.C, ZRegister.L), ZLine.ldImm8(ZRegister.B, 0)) + case ZExpressionTarget.DE => List(ZLine.ld8(ZRegister.E, ZRegister.L), ZLine.ldImm8(ZRegister.D, 0)) }) } case "nonet" => @@ -206,6 +257,24 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } else { ??? } + case ZExpressionTarget.BC => + if (ctx.options.flag(CompilationFlag.EmitExtended80Opcodes)) { + List( + ZLine.ld8(ZRegister.C, ZRegister.A), + ZLine.ldImm8(ZRegister.B, 0), + ZLine.register(RL, ZRegister.B)) + } else { + ??? + } + case ZExpressionTarget.DE => + if (ctx.options.flag(CompilationFlag.EmitExtended80Opcodes)) { + List( + ZLine.ld8(ZRegister.C, ZRegister.A), + ZLine.ldImm8(ZRegister.B, 0), + ZLine.register(RL, ZRegister.B)) + } else { + ??? + } }) } case "&&" => @@ -237,7 +306,9 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case 1 => targetifyA(target, ZBuiltIns.compile8BitOperation(ctx, AND, params), isSigned = false) case 2 => targetifyHL(target, ZBuiltIns.compile16BitOperation(ctx, AND, params)) } - case "*" => ??? + case "*" => + assertAllBytes("Long multiplication not supported", ctx, params) + targetifyA(target, Z80Multiply.compile8BitMultiply(ctx, params), isSigned = false) case "|" => getParamMaxSize(ctx, params) match { case 1 => targetifyA(target, ZBuiltIns.compile8BitOperation(ctx, OR, params), isSigned = false) @@ -278,7 +349,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, "<", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + case _ => Z80Comparisons.compileLongRelativeComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) } } case ">=" => @@ -286,7 +358,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, ">=", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + case _ => Z80Comparisons.compileLongRelativeComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) } } case ">" => @@ -294,7 +367,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, ">", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + case _ => Z80Comparisons.compileLongRelativeComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) } } case "<=" => @@ -302,7 +376,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, "<=", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + case _ => Z80Comparisons.compileLongRelativeComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) } } case "==" => @@ -310,7 +385,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, "==", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, ComparisonType.Equal, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, ComparisonType.Equal, l, r, branches) + case _ => Z80Comparisons.compileLongEqualityComparison(ctx, ComparisonType.Equal, l, r, branches) } } case "!=" => @@ -318,7 +394,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { compileTransitiveRelation(ctx, "!=", params, target, branches) { (l, r) => size match { case 1 => Z80Comparisons.compile8BitComparison(ctx, ComparisonType.NotEqual, l, r, branches) - case _ => ??? + case 2 => Z80Comparisons.compile16BitComparison(ctx, ComparisonType.NotEqual, l, r, branches) + case _ => Z80Comparisons.compileLongEqualityComparison(ctx, ComparisonType.NotEqual, l, r, branches) } } case "+=" => @@ -369,7 +446,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case "*=" => assertAllBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertAssignmentLike(ctx, params) - ??? + Z80Multiply.compile8BitInPlaceMultiply(ctx, l, r) case "*'=" => assertAllBytes("Long multiplication not supported", ctx, params) val (l, r, 1) = assertAssignmentLike(ctx, params) @@ -471,6 +548,24 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } } + def calculateLoadAndStoreForByte(ctx: CompilationContext, expr: LhsExpression): (List[ZLine], List[ZLine]) = { + Z80ExpressionCompiler.calculateAddressToAppropriatePointer(ctx, expr) match { + case Some((LocalVariableAddressViaHL, calculate)) => + (calculate :+ ZLine.ld8(ZRegister.A, ZRegister.MEM_HL)) -> List(ZLine.ld8(ZRegister.MEM_HL, ZRegister.A)) + case Some((LocalVariableAddressViaIX(offset), calculate)) => + (calculate :+ ZLine.ldViaIx(ZRegister.A, offset)) -> List(ZLine.ldViaIx(offset, ZRegister.A)) + case Some((LocalVariableAddressViaIY(offset), calculate)) => + (calculate :+ ZLine.ldViaIy(ZRegister.A, offset)) -> List(ZLine.ldViaIy(offset, ZRegister.A)) + case None => expr match { + case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) => + val lo = calculateLoadAndStoreForByte(ctx, l) + val (_, hiStore) = calculateLoadAndStoreForByte(ctx, h) + lo._1 -> (lo._2 ++ List(ZLine.ldImm8(ZRegister.A, 0)) ++ hiStore) + case _ => ??? + } + } + } + def calculateAddressToAppropriatePointer(ctx: CompilationContext, expr: LhsExpression): Option[(LocalVariableAddressOperand, List[ZLine])] = { val env = ctx.env expr match { @@ -515,6 +610,8 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A => List(ZLine.ldAbs8(ZRegister.A, sourceAddr)) case ZExpressionTarget.HL => List(ZLine.ldAbs8(ZRegister.A, sourceAddr), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.BC => List(ZLine.ldAbs8(ZRegister.A, sourceAddr), ZLine.ld8(ZRegister.C, ZRegister.A), ZLine.ldImm8(ZRegister.B, 0)) + case ZExpressionTarget.DE => List(ZLine.ldAbs8(ZRegister.A, sourceAddr), ZLine.ld8(ZRegister.E, ZRegister.A), ZLine.ldImm8(ZRegister.D, 0)) } } @@ -522,7 +619,9 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { target match { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A => List(ZLine.ldViaIx(ZRegister.A, offset)) - case ZExpressionTarget.HL => List(ZLine.ldViaIx(ZRegister.A, offset), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.HL => List(ZLine.ldViaIx(ZRegister.L, offset), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.BC => List(ZLine.ldViaIx(ZRegister.C, offset), ZLine.ldImm8(ZRegister.B, 0)) + case ZExpressionTarget.DE => List(ZLine.ldViaIx(ZRegister.E, offset), ZLine.ldImm8(ZRegister.D, 0)) } } @@ -530,20 +629,16 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { target match { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A => List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL)) - case ZExpressionTarget.HL => List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.HL => List(ZLine.ld8(ZRegister.L, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.H, 0)) + case ZExpressionTarget.BC => List(ZLine.ld8(ZRegister.C, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.B, 0)) + case ZExpressionTarget.DE => List(ZLine.ld8(ZRegister.E, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.D, 0)) } } def signExtend(targetAddr: Constant, hiRegister: ZRegister.Value, bytes: Int, signedSource: Boolean): List[ZLine] = { if (bytes == 0) return Nil val prepareA = if (signedSource) { - val prefix = if (hiRegister == ZRegister.A) Nil else List(ZLine.ld8(ZRegister.A, hiRegister)) - val label = Z80Compiler.nextLabel("sx") - prefix ++ List( - ZLine.imm8(OR, 0x7f), - ZLine.jump(label, IfFlagSet(ZFlag.S)), - ZLine.ldImm8(ZRegister.A, 0), - ZLine.label(label)) + signExtendHighestByte(hiRegister) } else { List(ZLine.ldImm8(ZRegister.A, 0)) } @@ -551,16 +646,20 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { prepareA ++ fillUpperBytes } + private def signExtendHighestByte(hiRegister: ZRegister.Value) = { + val prefix = if (hiRegister == ZRegister.A) Nil else List(ZLine.ld8(ZRegister.A, hiRegister)) + val label = Z80Compiler.nextLabel("sx") + prefix ++ List( + ZLine.imm8(OR, 0x7f), + ZLine.jump(label, IfFlagSet(ZFlag.S)), + ZLine.ldImm8(ZRegister.A, 0), + ZLine.label(label)) + } + def signExtendViaIX(targetOffset: Int, hiRegister: ZRegister.Value, bytes: Int, signedSource: Boolean): List[ZLine] = { if (bytes == 0) return Nil val prepareA = if (signedSource) { - val prefix = if (hiRegister == ZRegister.A) Nil else List(ZLine.ld8(ZRegister.A, hiRegister)) - val label = Z80Compiler.nextLabel("sx") - prefix ++ List( - ZLine.imm8(OR, 0x7f), - ZLine.jump(label, IfFlagSet(ZFlag.S)), - ZLine.ldImm8(ZRegister.A, 0), - ZLine.label(label)) + signExtendHighestByte(hiRegister) } else { List(ZLine.ldImm8(ZRegister.A, 0)) } @@ -639,8 +738,25 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case (None, offset) => ZLine.ld8(ZRegister.A, ZRegister.L) :: storeA((p.value + offset).quickSimplify, 1, signedSource) } } - //TODO - case SeparateBytesExpression(hi, lo) => ??? + case SeparateBytesExpression(hi: LhsExpression, lo: LhsExpression) => + Z80ExpressionCompiler.stashHLIfChanged(ZLine.ld8(ZRegister.A, ZRegister.L) :: storeA(ctx, lo, signedSource)) ++ + (ZLine.ld8(ZRegister.A, ZRegister.H) :: storeA(ctx, hi, signedSource)) + case _: SeparateBytesExpression => + ErrorReporting.error("Invalid `:`", target.position) + Nil + } + } + + def storeLarge(ctx: CompilationContext, target: LhsExpression, source: Expression): List[ZLine] = { + val env = ctx.env + target match { + case VariableExpression(vname) => + env.get[Variable](vname) match { + case v: Variable => + val size = v.typ.size + compileByteReads(ctx, source, size, ZExpressionTarget.HL).zip(compileByteStores(ctx, target, size)).flatMap(t => t._1 ++ t._2) + } + case _ => ??? } } @@ -679,7 +795,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } } - def compileByteReads(ctx: CompilationContext, rhs: Expression, size: Int): List[List[ZLine]] = { + def compileByteReads(ctx: CompilationContext, rhs: Expression, size: Int, temporaryTarget: ZExpressionTarget.Value): List[List[ZLine]] = { if (size == 1) throw new IllegalArgumentException val env = ctx.env env.eval(rhs) match { @@ -691,20 +807,20 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { env.get[Variable](vname) match { case v: VariableInMemory => List.tabulate(size) { i => - if (i < size) { + if (i < v.typ.size) { List(ZLine.ldAbs8(ZRegister.A, v.toAddress + i)) } else if (v.typ.isSigned) { - ??? + ZLine.ldAbs8(ZRegister.A, v.toAddress + v.typ.size - 1) :: signExtendHighestByte(ZRegister.A) } else { List(ZLine.ldImm8(ZRegister.A, 0)) } } case v: StackVariable => List.tabulate(size) { i => - if (i < size) { + if (i < v.typ.size) { List(ZLine.ldViaIx(ZRegister.A, v.baseOffset + i)) } else if (v.typ.isSigned) { - ??? + ZLine.ldViaIx(ZRegister.A, v.baseOffset + v.typ.size - 1) :: signExtendHighestByte(ZRegister.A) } else { List(ZLine.ldImm8(ZRegister.A, 0)) } @@ -721,15 +837,42 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } } case _ => - List.tabulate(size) { i => - if (i == 0) { - compileToHL(ctx, rhs) :+ ZLine.ld8(ZRegister.A, ZRegister.L) - } else if (i == 1) { - List(ZLine.ld8(ZRegister.A, ZRegister.H)) - } else { - // TODO: signed words? - List(ZLine.ldImm8(ZRegister.A, 0)) - } + val (h, l) = temporaryTarget match { + case ZExpressionTarget.HL => ZRegister.H -> ZRegister.L + case ZExpressionTarget.BC => ZRegister.B -> ZRegister.C + case ZExpressionTarget.DE => ZRegister.D -> ZRegister.E + case _ => throw new IllegalArgumentException("temporaryTarget") + } + val typ = getExpressionType(ctx, rhs) + typ.size match { + case 1 => + List.tabulate(size) { i => + if (i == 0) { + if (typ.isSigned) { + (compileToA(ctx, rhs) :+ ZLine.ld8(l, ZRegister.A)) ++ + signExtendHighestByte(ZRegister.A) ++ List(ZLine.ld8(h, ZRegister.A), ZLine.ld8(ZRegister.A, l)) + } else { + compileToA(ctx, rhs) + } + } else if (typ.isSigned) { + List(ZLine.ld8(ZRegister.A, h)) + } else { + // TODO: signed words? + List(ZLine.ldImm8(ZRegister.A, 0)) + } + } + case 2 => + List.tabulate(size) { i => + if (i == 0) { + compile(ctx, rhs, temporaryTarget, BranchSpec.None) :+ ZLine.ld8(ZRegister.A, l) + } else if (i == 1) { + List(ZLine.ld8(ZRegister.A, h)) + } else { + // TODO: signed words? + List(ZLine.ldImm8(ZRegister.A, 0)) + } + } + case _ => ??? } } } diff --git a/src/main/scala/millfork/compiler/z80/Z80Multiply.scala b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala new file mode 100644 index 00000000..bb6923c0 --- /dev/null +++ b/src/main/scala/millfork/compiler/z80/Z80Multiply.scala @@ -0,0 +1,128 @@ +package millfork.compiler.z80 + +import millfork.assembly.z80._ +import millfork.compiler.{BranchSpec, CompilationContext} +import millfork.env.{CompoundConstant, Constant, MathOperator, NumericConstant} +import millfork.node.{ConstantArrayElementExpression, Expression, LhsExpression, ZRegister} + +/** + * @author Karol Stasiak + */ +object Z80Multiply { + + /** + * Compiles A = A * D + */ + private def multiplication(): List[ZLine] = { + import millfork.assembly.z80.ZOpcode._ + import ZRegister._ + import ZLine._ + val lblAdd = Z80Compiler.nextLabel("mu") + val lblLoop = Z80Compiler.nextLabel("mu") + val lblStart = Z80Compiler.nextLabel("mu") + List( + ld8(E, A), + ldImm8(A, 0), + jumpR(lblStart), + label(lblAdd), + register(ADD, E), + label(lblLoop), + register(SLA, E), + label(lblStart), + register(SRL, D), + jumpR(lblAdd, IfFlagSet(ZFlag.C)), + jumpR(lblLoop, IfFlagClear(ZFlag.Z))) + } + + /** + * Calculate A = l * r + */ + def compile8BitMultiply(ctx: CompilationContext, params: List[Expression]): List[ZLine] = { + var numericConst = 1L + var otherConst: Constant = NumericConstant(1, 1) + val filteredParams = params.filter { expr => + ctx.env.eval(expr) match { + case None => + true + case Some(NumericConstant(n, _)) => + numericConst *= n + false + case Some(c) => + otherConst = CompoundConstant(MathOperator.Times, otherConst, c).loByte.quickSimplify + false + } + } + val productOfConstants = CompoundConstant(MathOperator.Times, otherConst, NumericConstant(numericConst & 0xff, 1)).quickSimplify + (filteredParams, otherConst) match { + case (Nil, NumericConstant(n, _)) => List(ZLine.ldImm8(ZRegister.A, (numericConst * n).toInt)) + case (Nil, _) => List(ZLine.ldImm8(ZRegister.A, productOfConstants)) + case (List(a), NumericConstant(n, _)) => Z80ExpressionCompiler.compileToA(ctx, a) ++ compile8BitMultiply((numericConst * n).toInt) + case (List(a), _) => + compile8BitMultiply(ctx, a, ConstantArrayElementExpression(productOfConstants)) + case (List(a, b), NumericConstant(n, _)) => + compile8BitMultiply(ctx, a, b) ++ compile8BitMultiply((numericConst * n).toInt) + case _ => ??? + } + } + + /** + * Calculate A = l * r + */ + def compile8BitMultiply(ctx: CompilationContext, l: Expression, r: Expression): List[ZLine] = { + (ctx.env.eval(l), ctx.env.eval(r)) match { + case (Some(a), Some(b)) => List(ZLine.ldImm8(ZRegister.A, CompoundConstant(MathOperator.Times, a, b).loByte.quickSimplify)) + case (Some(NumericConstant(count, _)), None) => Z80ExpressionCompiler.compileToA(ctx, r) ++ compile8BitMultiply(count.toInt) + case (None, Some(NumericConstant(count, _))) => Z80ExpressionCompiler.compileToA(ctx, l) ++ compile8BitMultiply(count.toInt) + case _ => + val lb = Z80ExpressionCompiler.compileToA(ctx, l) + val rb = Z80ExpressionCompiler.compileToA(ctx, r) + val load = if (lb.exists(Z80ExpressionCompiler.changesDE)) { + lb ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) ++ Z80ExpressionCompiler.stashDEIfChanged(rb) + } else { + rb ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) ++ lb + } + load ++ multiplication() + } + } + + /** + * Calculate A = l * r + */ + def compile8BitInPlaceMultiply(ctx: CompilationContext, l: LhsExpression, r: Expression): List[ZLine] = { + ctx.env.eval(r) match { + case Some(NumericConstant(count, _)) => + val (load, store) = Z80ExpressionCompiler.calculateLoadAndStoreForByte(ctx, l) + load ++ compile8BitMultiply(count.toInt) ++ store + case Some(c) => + val (load, store) = Z80ExpressionCompiler.calculateLoadAndStoreForByte(ctx, l) + load ++ List(ZLine.ldImm8(ZRegister.D, c)) ++ multiplication() ++ store + case _ => + val (load, store) = Z80ExpressionCompiler.calculateLoadAndStoreForByte(ctx, l) + val rb = Z80ExpressionCompiler.compileToA(ctx, r) + val loadRegisters = if (load.exists(Z80ExpressionCompiler.changesDE)) { + load ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) ++ Z80ExpressionCompiler.stashDEIfChanged(rb) + } else { + rb ++ List(ZLine.ld8(ZRegister.D, ZRegister.A)) ++ load + } + loadRegisters ++ multiplication() ++ store + } + } + + /** + * Calculate A = count * x + */ + def compile8BitMultiply(count: Int): List[ZLine] = { + import millfork.assembly.z80.ZOpcode._ + import ZRegister._ + count match { + case 0 => List(ZLine.ldImm8(A, 0)) + case 1 => Nil + case n if n > 0 && n.-(1).&(n).==(0) => List.fill(Integer.numberOfTrailingZeros(n))(ZLine.register(SLA, A)) + case _ => + ZLine.ld8(E,A) :: Integer.toString(count & 0xff, 2).tail.flatMap{ + case '0' => List(ZLine.register(SLA, A)) + case '1' => List(ZLine.register(SLA, A), ZLine.register(ADD, E)) + }.toList + } + } +} diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index 30348aaf..abcad5ad 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -68,7 +68,7 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { case 0 => ??? case 1 => Z80ExpressionCompiler.compileToA(ctx, source) ++ Z80ExpressionCompiler.storeA(ctx, destination, sourceType.isSigned) case 2 => Z80ExpressionCompiler.compileToHL(ctx, source) ++ Z80ExpressionCompiler.storeHL(ctx, destination, sourceType.isSigned) - case _ => ??? // large object copy + case s => Z80ExpressionCompiler.storeLarge(ctx, destination, source) } case s: IfStatement => compileIfStatement(ctx, s) diff --git a/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala b/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala index d197f343..6c012139 100644 --- a/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala +++ b/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala @@ -368,11 +368,12 @@ object ZBuiltIns { } } val store = Z80ExpressionCompiler.compileByteStores(ctx, lhs, size) - val loadLeft = Z80ExpressionCompiler.compileByteReads(ctx, lhs, size) - val loadRight = Z80ExpressionCompiler.compileByteReads(ctx, rhs, size) + val loadLeft = Z80ExpressionCompiler.compileByteReads(ctx, lhs, size, ZExpressionTarget.HL) + val loadRight = Z80ExpressionCompiler.compileByteReads(ctx, rhs, size, ZExpressionTarget.BC) List.tabulate(size) {i => // TODO: stash things correctly? - val firstPhase = loadRight(i) ++ List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ (loadLeft(i) :+ ZLine.register(if (i==0) opcodeFirst else opcodeLater, ZRegister.E)) + val firstPhase = loadRight(i) ++ List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ + (Z80ExpressionCompiler.stashBCIfChanged(loadLeft(i)) :+ ZLine.register(if (i==0) opcodeFirst else opcodeLater, ZRegister.E)) val secondPhase = if (decimal) firstPhase :+ ZLine.implied(ZOpcode.DAA) else firstPhase secondPhase ++ store(i) }.flatten diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index f920c68a..c2f98e3b 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -101,6 +101,13 @@ object MosNiceFunctionProperty { case object DoesntChangeZpRegister extends NiceFunctionProperty("reg") } +object Z80NiceFunctionProperty { + case object DoesntChangeBC extends NiceFunctionProperty("BC") + case object DoesntChangeDE extends NiceFunctionProperty("DE") + case object DoesntChangeHL extends NiceFunctionProperty("HL") + case object DoesntChangeIY extends NiceFunctionProperty("IY") +} + object MosRegister extends Enumeration { val A, X, Y, AX, AY, YA, XA, XY, YX, AW = Value } diff --git a/src/main/scala/millfork/output/Z80InliningCalculator.scala b/src/main/scala/millfork/output/Z80InliningCalculator.scala index ccca1d77..c9da3c00 100644 --- a/src/main/scala/millfork/output/Z80InliningCalculator.scala +++ b/src/main/scala/millfork/output/Z80InliningCalculator.scala @@ -1,16 +1,76 @@ package millfork.output -import millfork.assembly.z80.ZLine +import millfork.assembly.z80._ import millfork.compiler.AbstractCompiler +import millfork.env.{ExternFunction, Label, MemoryAddressConstant, NormalFunction} + +import scala.collection.GenTraversableOnce /** * @author Karol Stasiak */ object Z80InliningCalculator extends AbstractInliningCalculator[ZLine] { - // TODO + import ZOpcode._ - override def codeForInlining(fname: String, functionsAlreadyKnownToBeNonInlineable: Set[String], code: List[ZLine]): Option[List[ZLine]] = None + private val badOpcodes = Set(RET, RETI, RETN, CALL, BYTE, POP, PUSH) + private val jumpingRelatedOpcodes = Set(LABEL, JP, JR) - override def inline(code: List[ZLine], inlinedFunctions: Map[String, List[ZLine]], compiler: AbstractCompiler[ZLine]): List[ZLine] = code + override def codeForInlining(fname: String, functionsAlreadyKnownToBeNonInlineable: Set[String], code: List[ZLine]): Option[List[ZLine]] = { + if (code.isEmpty) return None + code.last match { + case ZLine(RET, NoRegisters, _, _) => + case _ => return None + } + var result = code.init + while (result.nonEmpty && ZOpcodeClasses.NoopDiscards(result.last.opcode)) { + result = result.init + } + if (result.head.opcode == LABEL && result.head.parameter == Label(fname).toAddress) result = result.tail + if (result.exists { + case ZLine(op, _, MemoryAddressConstant(Label(l)), _) if jumpingRelatedOpcodes(op) => + !l.startsWith(".") + case ZLine(CALL, _, MemoryAddressConstant(th: ExternFunction), _) => false + case ZLine(CALL, _, MemoryAddressConstant(th: NormalFunction), _) => + !functionsAlreadyKnownToBeNonInlineable(th.name) + case ZLine(op, _, _, _) if jumpingRelatedOpcodes(op) || badOpcodes(op) => true + case _ => false + }) return None + Some(result) + } + + def wrap(registers: ZRegisters, compiler: AbstractCompiler[ZLine], lines: List[ZLine]): List[ZLine] = registers match { + case NoRegisters => lines + case IfFlagClear(flag) => + val label = compiler.nextLabel("ai") + ZLine.jump(label, IfFlagSet(flag)) :: (lines :+ ZLine.label(label)) + case IfFlagSet(flag) => + val label = compiler.nextLabel("ai") + ZLine.jump(label, IfFlagClear(flag)) :: (lines :+ ZLine.label(label)) + case _ => throw new IllegalArgumentException("registers") + } + + override def inline(code: List[ZLine], inlinedFunctions: Map[String, List[ZLine]], compiler: AbstractCompiler[ZLine]): List[ZLine] = { + code.flatMap { + case ZLine(CALL, registers, p, true) if inlinedFunctions.contains(p.toString) => + val labelPrefix = compiler.nextLabel("ai") + wrap(registers, compiler, + inlinedFunctions(p.toString).map { + case line@ZLine(_, _, MemoryAddressConstant(Label(label)), _) => + val newLabel = MemoryAddressConstant(Label(labelPrefix + label)) + line.copy(parameter = newLabel) + case l => l + }) + case ZLine(JP | JR, registers, p, true) if inlinedFunctions.contains(p.toString) => + val labelPrefix = compiler.nextLabel("ai") + wrap(registers, compiler, + inlinedFunctions(p.toString).map { + case line@ZLine(_, _, MemoryAddressConstant(Label(label)), _) => + val newLabel = MemoryAddressConstant(Label(labelPrefix + label)) + line.copy(parameter = newLabel) + case l => l + } :+ ZLine.implied(RET)) + case x => List(x) + } + } } diff --git a/src/test/scala/millfork/test/ArraySuite.scala b/src/test/scala/millfork/test/ArraySuite.scala index dd94ebca..26ba5742 100644 --- a/src/test/scala/millfork/test/ArraySuite.scala +++ b/src/test/scala/millfork/test/ArraySuite.scala @@ -67,7 +67,7 @@ class ArraySuite extends FunSuite with Matchers { } test("Array assignment through a pointer") { - val m = EmuUnoptimizedRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | array output [3] @$c000 | pointer p @@ -80,13 +80,14 @@ class ArraySuite extends FunSuite with Matchers { | w = $105 | p[i]:ignored = w | } - """.stripMargin) - m.readByte(0xc001) should equal(1) + """.stripMargin) { m => + m.readByte(0xc001) should equal(1) + } } test("Array in place math") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | array output [4] @$c000 | void main () { diff --git a/src/test/scala/millfork/test/ByteMathSuite.scala b/src/test/scala/millfork/test/ByteMathSuite.scala index cf119350..65bafa5a 100644 --- a/src/test/scala/millfork/test/ByteMathSuite.scala +++ b/src/test/scala/millfork/test/ByteMathSuite.scala @@ -10,7 +10,7 @@ import org.scalatest.{FunSuite, Matchers} class ByteMathSuite extends FunSuite with Matchers { test("Complex expression") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | void main () { @@ -23,7 +23,7 @@ class ByteMathSuite extends FunSuite with Matchers { } test("Byte addition") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | byte a @@ -35,7 +35,7 @@ class ByteMathSuite extends FunSuite with Matchers { } test("Byte addition 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | byte a @@ -47,7 +47,7 @@ class ByteMathSuite extends FunSuite with Matchers { } test("In-place byte addition") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | array output[3] @$c000 | byte a @@ -101,7 +101,7 @@ class ByteMathSuite extends FunSuite with Matchers { } test("In-place byte addition 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | array output[3] @$c000 | void main () { @@ -137,7 +137,7 @@ class ByteMathSuite extends FunSuite with Matchers { } private def multiplyCase1(x: Int, y: Int): Unit = { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( s""" | byte output @$$c000 | void main () { @@ -165,7 +165,7 @@ class ByteMathSuite extends FunSuite with Matchers { } private def multiplyCase2(x: Int, y: Int): Unit = { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( s""" | byte output @$$c000 | void main () { @@ -178,7 +178,7 @@ class ByteMathSuite extends FunSuite with Matchers { } test("Byte multiplication 2") { - EmuUltraBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | import zp_reg | byte output1 @$c001 @@ -212,8 +212,15 @@ class ByteMathSuite extends FunSuite with Matchers { | } | | noinline void crash_if_bad() { + | #if ARCH_6502 | if output1 != 20 { asm { lda $bfff }} | if output2 != 27 { asm { lda $bfff }} + | #elseif ARCH_Z80 + | if output1 != 20 { asm { ld a,($bfff) }} + | if output2 != 27 { asm { ld a,($bfff) }} + | #else + | #error unsupported architecture + | #endif | } """.stripMargin){m => m.readByte(0xc002) should equal(27) @@ -238,7 +245,7 @@ class ByteMathSuite extends FunSuite with Matchers { } private def multiplyCase3(x: Int, y: Int): Unit = { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( s""" | import zp_reg | byte output @$$c000 diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index ac02bbfd..2e318791 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -76,7 +76,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Does it even work") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | word output @$c000 | void main () { @@ -99,7 +99,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if 2222 == 3333 { output -= 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Word comparison == and !=") { @@ -122,7 +122,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if a != 0 { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Word comparison <=") { @@ -143,7 +143,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if a <= c { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Word comparison <") { val src = @@ -162,7 +162,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if a < 257 { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } @@ -183,7 +183,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if c > 0 { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Word comparison >=") { @@ -206,7 +206,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if a >= 0 { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Signed comparison >=") { @@ -265,7 +265,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Multiple params for equality") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | void main () { @@ -281,7 +281,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Multiple params for inequality") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | void main () { @@ -297,7 +297,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Warnings") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | byte output @$c000 | void main () { @@ -339,7 +339,7 @@ class ComparisonSuite extends FunSuite with Matchers { | if c > 335444 { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } test("Mixed type comparison") { @@ -357,6 +357,6 @@ class ComparisonSuite extends FunSuite with Matchers { | if x < z { output += 1 } | } """.stripMargin - EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(1)) + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(src)(_.readByte(0xc000) should equal(1)) } } diff --git a/src/test/scala/millfork/test/LongTest.scala b/src/test/scala/millfork/test/LongTest.scala index 195a599f..543525f8 100644 --- a/src/test/scala/millfork/test/LongTest.scala +++ b/src/test/scala/millfork/test/LongTest.scala @@ -1,6 +1,7 @@ package millfork.test -import millfork.test.emu.EmuBenchmarkRun +import millfork.Cpu +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun} import org.scalatest.{FunSuite, Matchers} /** @@ -9,7 +10,7 @@ import org.scalatest.{FunSuite, Matchers} class LongTest extends FunSuite with Matchers { test("Long assignment") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output4 @$c000 | long output2 @$c004 @@ -28,7 +29,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long assignment 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output4 @$c000 | long output2 @$c004 @@ -51,7 +52,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long addition") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -71,7 +72,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long addition 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -85,7 +86,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long subtraction") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -105,7 +106,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long subtraction 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -119,7 +120,7 @@ class LongTest extends FunSuite with Matchers { } } test("Long subtraction 3") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -139,7 +140,7 @@ class LongTest extends FunSuite with Matchers { } test("Long AND") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output @$c000 | void main () { @@ -159,7 +160,7 @@ class LongTest extends FunSuite with Matchers { } test("Long INC/DEC") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( """ | long output0 @$c000 | long output1 @$c004 diff --git a/src/test/scala/millfork/test/SignExtensionSuite.scala b/src/test/scala/millfork/test/SignExtensionSuite.scala index d4461287..d16a8787 100644 --- a/src/test/scala/millfork/test/SignExtensionSuite.scala +++ b/src/test/scala/millfork/test/SignExtensionSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun} +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun} import org.scalatest.{FunSuite, Matchers} /** @@ -10,7 +10,7 @@ import org.scalatest.{FunSuite, Matchers} class SignExtensionSuite extends FunSuite with Matchers { test("Sbyte to Word") { - EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(""" + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(""" | word output @$c000 | void main () { | sbyte b @@ -22,7 +22,7 @@ class SignExtensionSuite extends FunSuite with Matchers { } } test("Sbyte to Word 2") { - EmuUnoptimizedRun(""" + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(""" | word output @$c000 | void main () { | output = b() @@ -30,10 +30,10 @@ class SignExtensionSuite extends FunSuite with Matchers { | sbyte b() { | return -1 | } - """.stripMargin).readWord(0xc000) should equal(0xffff) + """.stripMargin){m => m.readWord(0xc000) should equal(0xffff)} } test("Sbyte to Long") { - EmuUnoptimizedRun(""" + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(""" | long output @$c000 | void main () { | output = 421 @@ -42,11 +42,11 @@ class SignExtensionSuite extends FunSuite with Matchers { | sbyte b() { | return -1 | } - """.stripMargin).readLong(0xc000) should equal(420) + """.stripMargin){m => m.readLong(0xc000) should equal(420)} } test("Optimize pointless sign extension") { - EmuBenchmarkRun(""" + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)(""" | array output [10] @$c000 | word w | void main () { diff --git a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala index 606b652f..02955291 100644 --- a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala +++ b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala @@ -7,7 +7,7 @@ import millfork.output.MemoryBank * @author Karol Stasiak */ object EmuBenchmarkRun { - def apply(source: String)(verifier: MemoryBank => Unit) = { + def apply(source: String)(verifier: MemoryBank => Unit): Unit = { val (Timings(t0, _), m0) = EmuUnoptimizedRun.apply2(source) val (Timings(t1, _), m1) = EmuOptimizedRun.apply2(source) val (Timings(t2, _), m2) = EmuOptimizedInlinedRun.apply2(source) @@ -26,16 +26,21 @@ object EmuBenchmarkRun { } object EmuZ80BenchmarkRun { - def apply(source: String)(verifier: MemoryBank => Unit) = { + def apply(source: String)(verifier: MemoryBank => Unit): Unit = { val (Timings(t0, _), m0) = EmuUnoptimizedZ80Run.apply2(source) val (Timings(t1, _), m1) = EmuOptimizedZ80Run.apply2(source) + val (Timings(t2, _), m2) = EmuOptimizedInlinedZ80Run.apply2(source) println(f"Before optimization: $t0%7d") println(f"After optimization: $t1%7d") + println(f"After inlining: $t2%7d") println(f"Gain: ${(100L * (t0 - t1) / t0.toDouble).round}%7d%%") + println(f"Gain with inlining: ${(100L * (t0 - t2) / t0.toDouble).round}%7d%%") println(f"Running unoptimized") verifier(m0) println(f"Running optimized") verifier(m1) + println(f"Running optimized inlined") + verifier(m2) } } diff --git a/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala b/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala index 4dde2c22..055a0dc4 100644 --- a/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala +++ b/src/test/scala/millfork/test/emu/EmuOptimizedInlinedRun.scala @@ -1,6 +1,7 @@ package millfork.test.emu import millfork.assembly.mos.opt.{LaterOptimizations, ZeropageRegisterOptimizations} +import millfork.assembly.z80.opt.Z80OptimizationPresets import millfork.{Cpu, OptimizationPresets} /** @@ -22,4 +23,8 @@ object EmuOptimizedInlinedRun extends EmuRun( } +object EmuOptimizedInlinedZ80Run extends EmuZ80Run(Cpu.Z80, OptimizationPresets.NodeOpt, Z80OptimizationPresets.Good) { + override def inline: Boolean = true +} + diff --git a/src/test/scala/millfork/test/emu/EmuZ80Run.scala b/src/test/scala/millfork/test/emu/EmuZ80Run.scala index bf01c9a8..c6aee429 100644 --- a/src/test/scala/millfork/test/emu/EmuZ80Run.scala +++ b/src/test/scala/millfork/test/emu/EmuZ80Run.scala @@ -19,6 +19,7 @@ import org.scalatest.Matchers * @author Karol Stasiak */ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], assemblyOptimizations: List[AssemblyOptimization[ZLine]]) extends Matchers { + def inline: Boolean = false private val TooManyCycles: Long = 1000000 @@ -27,7 +28,9 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio Console.err.flush() println(source) val platform = EmuPlatform.get(cpu) - val extraFlags = Map(CompilationFlag.LenientTextEncoding -> true) + val extraFlags = Map( + CompilationFlag.InlineFunctions -> this.inline, + CompilationFlag.LenientTextEncoding -> true) val options = CompilationOptions(platform, millfork.Cpu.defaultFlags(cpu).map(_ -> true).toMap ++ extraFlags, None, 0) ErrorReporting.hasErrors = false ErrorReporting.verbosity = 999 @@ -68,7 +71,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio val assembler = new Z80Assembler(program, env2, platform) val output = assembler.assemble(callGraph, assemblyOptimizations, options) println(";;; compiled: -----------------") - output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).filterNot(_.contains("; DISCARD_")).foreach(println) + output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).filterNot(_.contains("////; DISCARD_")).foreach(println) println(";;; ---------------------------") assembler.labelMap.foreach { case (l, addr) => println(f"$l%-15s $$$addr%04x") }