diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index d9f49458..59977d9a 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -40,6 +40,8 @@ where `11111` is a sequential number and `xx` is the type: * `lj` – extra labels generated when converting invalid short jumps to long jumps +* `me` – start of a `for` loop doing bulk memory operations + * `no` – nonet to word extension caused by the `nonet` operator * `od` – end of a `do-while` statement diff --git a/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala new file mode 100644 index 00000000..cd19a955 --- /dev/null +++ b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala @@ -0,0 +1,399 @@ +package millfork.compiler.z80 + +import millfork.CompilationFlag +import millfork.assembly.z80._ +import millfork.compiler.CompilationContext +import millfork.env._ +import millfork.node._ +import millfork.assembly.z80.ZOpcode._ + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +object Z80BulkMemoryOperations { + import Z80StatementCompiler.compileForStatement + + def compileMemcpy(ctx: CompilationContext, target: IndexedExpression, source: IndexedExpression, f: ForStatement): List[ZLine] = { + val sourceOffset = removeVariableOnce(f.variable, source.index).getOrElse(return compileForStatement(ctx, f)) + if (!sourceOffset.isPure) return compileForStatement(ctx, f) + val sourceIndexExpression = SumExpression(List(false -> sourceOffset, false -> f.start), decimal = false) + val calculateSource = Z80ExpressionCompiler.calculateAddressToHL(ctx, IndexedExpression(source.name, sourceIndexExpression)) + compileMemoryBulk(ctx, target, f, + useDEForTarget = true, + preferDecreasing = false, + _ => calculateSource -> Nil, + next => List( + ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), + ZLine.register(next, ZRegister.HL) + ), + decreasing => Some(if (decreasing) LDDR else LDIR) + ) + } + + + def compileMemset(ctx: CompilationContext, target: IndexedExpression, source: Expression, f: ForStatement): List[ZLine] = { + val loadA = Z80ExpressionCompiler.stashHLIfChanged(Z80ExpressionCompiler.compileToA(ctx, source)) :+ ZLine.ld8(ZRegister.MEM_HL, ZRegister.A) + compileMemoryBulk(ctx, target, f, + useDEForTarget = false, + preferDecreasing = false, + _ => Nil -> Nil, + _ => loadA, + _ => None + ) + } + + def compileMemtransform(ctx: CompilationContext, target: IndexedExpression, operator: String, source: Expression, f: ForStatement): List[ZLine] = { + val c = determineExtraLoopRegister(ctx, f, source.containsVariable(f.variable)) + val load = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator, source, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + import scala.util.control.Breaks._ + breakable{ + return compileMemoryBulk(ctx, target, f, + useDEForTarget = false, + preferDecreasing = true, + isSmall => if (isSmall) Nil -> c.initC else break, + _ => load ++ c.nextC, + _ => None + ) + } + compileForStatement(ctx, f) + } + + def compileMemtransform2(ctx: CompilationContext, + target1: IndexedExpression, operator1: String, source1: Expression, + target2: IndexedExpression, operator2: String, source2: Expression, + f: ForStatement): List[ZLine] = { + import scala.util.control.Breaks._ + val c = determineExtraLoopRegister(ctx, f, source1.containsVariable(f.variable) || source2.containsVariable(f.variable)) + val target1Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)) + val target2Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)) + val target1IndexExpression = if (c.countDownDespiteSyntax) { + SumExpression(List(false -> target1Offset, false -> f.end, true -> LiteralExpression(1, 1)), decimal = false) + } else { + SumExpression(List(false -> target1Offset, false -> f.start), decimal = false) + } + val target2IndexExpression = if (c.countDownDespiteSyntax) { + SumExpression(List(false -> target2Offset, false -> f.end, true -> LiteralExpression(1, 1)), decimal = false) + } else { + SumExpression(List(false -> target2Offset, false -> f.start), decimal = false) + } + val fused = target1.name == target2.name && ((ctx.env.eval(target1Offset), ctx.env.eval(target2Offset)) match { + case (Some(a), Some(b)) => a == b + case _ => false + }) + if (fused) { + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val loads = load1 ++ load2 + breakable{ + return compileMemoryBulk(ctx, target1, f, + useDEForTarget = false, + preferDecreasing = true, + isSmall => if (isSmall) Nil -> c.initC else break, + _ => loads ++ c.nextC, + _ => None + ) + } + } else { + val goodness1 = goodnessForHL(ctx, operator1, source1) + val goodness2 = goodnessForHL(ctx, operator2, source2) + val loads = if (goodness1 <= goodness2) { + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + load1 ++ load2 + } else { + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + load1 ++ load2 + } + val targetForDE = if (goodness1 <= goodness2) target1 else target2 + val targetForHL = if (goodness1 <= goodness2) target2 else target1 + val targetForHLIndexExpression = if (goodness1 <= goodness2) target2IndexExpression else target1IndexExpression + breakable{ + return compileMemoryBulk(ctx, targetForDE, f, + useDEForTarget = true, + preferDecreasing = true, + isSmall => if (isSmall) { + Z80ExpressionCompiler.calculateAddressToHL(ctx, IndexedExpression(targetForHL.name, targetForHLIndexExpression)) -> c.initC + } else break, + next => loads ++ (c.nextC :+ ZLine.register(next, ZRegister.HL)), + _ => None + ) + } + } + compileForStatement(ctx, f) + } + + private case class ExtraLoopRegister(loopRegister: ZRegister.Value, initC: List[ZLine], nextC: List[ZLine], countDownDespiteSyntax: Boolean) + + private def determineExtraLoopRegister(ctx: CompilationContext, f: ForStatement, readsLoopVariable: Boolean): ExtraLoopRegister = { + val useC = readsLoopVariable && (f.direction match { + case ForDirection.Until | ForDirection.To => true + case ForDirection.DownTo => ctx.env.eval(f.end) match { + case Some(NumericConstant(1, _)) => false + case _ => true + } + case ForDirection.ParallelUntil | ForDirection.ParallelTo => ctx.env.eval(f.start) match { + case Some(NumericConstant(1, _)) => false + case _ => true + } + }) + val loopRegister = if (useC) ZRegister.C else ZRegister.B + val countDown = f.direction == ForDirection.ParallelTo || f.direction == ForDirection.ParallelUntil || f.direction == ForDirection.DownTo + val countDownDespiteSyntax = f.direction == ForDirection.ParallelTo || f.direction == ForDirection.ParallelUntil + val initC = if (useC) Z80ExpressionCompiler.compileToA(ctx, f.direction match { + case ForDirection.ParallelTo => f.end + case ForDirection.ParallelUntil => SumExpression(List(false -> f.end, true -> LiteralExpression(1, 1)), decimal = false) + case _ => f.start + }) :+ ZLine.ld8(ZRegister.C, ZRegister.A) else Nil + val nextC = if (useC) List(ZLine.register(if (countDown) DEC else INC, ZRegister.C)) else Nil + ExtraLoopRegister(loopRegister, initC, nextC, countDownDespiteSyntax) + } + + private def goodnessForHL(ctx: CompilationContext, operator: String, source: Expression): Int = operator match { + case "<<=" | ">>=" => + if (ctx.options.flag(CompilationFlag.EmitExtended80Opcodes)) ctx.env.eval(source) match { + case Some(NumericConstant(n, _)) => + if (ctx.options.flag(CompilationFlag.OptimizeForSize)) { + 2 + } else { + // SLA (HL) = 15 cycles + // SLA A = 8 cycles + // LD A,(HL) + LD (HL),A = 14 cycles + (14 - 7 * (n max 0).toInt).max(0) + } + case _ => 0 + } else 0 + case "+=" | "-=" => ctx.env.eval(source) match { + case Some(NumericConstant(1, _)) => + if (ctx.options.flag(CompilationFlag.OptimizeForSize)) { + 2 + } else { + // INC (HL) = 11 cycles + // INC A = 4 cycles + // LD A,(HL) + LD (HL),A = 14 cycles + 7 + } + case Some(NumericConstant(2, _)) => + if (ctx.options.flag(CompilationFlag.OptimizeForSize)) { + 2 + } else { + // INC (HL) = 11 cycles + // ADD # = 7 cycles + // LD A,(HL) + LD (HL),A = 14 cycles + 0 + } + case _ => 0 + } + case "=" => ctx.env.eval(source) match { + case Some(_) => + if (ctx.options.flag(CompilationFlag.OptimizeForSize)) { + 1 + } else { + // LD (HL),# = 11 cycles + // LD A,# = 4 cycles + // LD (HL),A = 7 cycles + // so we don't save cycles, but we do save one byte + 1 + } + case _ => 0 + } + case _ => 0 + } + + private def buildMemtransformLoader(ctx: CompilationContext, element: ZRegister.Value, loopVariable: String, operator: String, source: Expression, loopRegister: ZRegister.Value): Option[List[ZLine]] = { + val env = ctx.env + if (operator == "=") { + source match { + case VariableExpression(n) if n == loopVariable => + if (element == ZRegister.MEM_HL) Some(List(ZLine.ld8(ZRegister.MEM_HL, loopRegister))) + else Some(List(ZLine.ld8(ZRegister.A, loopRegister), ZLine.ld8(element, ZRegister.A))) + case _ => env.eval(source) map { c => + if (element == ZRegister.MEM_HL) List(ZLine.ldImm8(ZRegister.MEM_HL, c)) + else List(ZLine.ldImm8(ZRegister.A, c), ZLine.ld8(element, ZRegister.A)) + } + } + } else { + val (operation, daa, shift) = operator match { + case "+=" => (ZOpcode.ADD, false, None) + case "+'=" => (ZOpcode.ADD, true, None) + case "-=" => (ZOpcode.SUB, false, None) + case "-'=" => (ZOpcode.SUB, true, None) + case "|=" => (ZOpcode.OR, false, None) + case "&=" => (ZOpcode.AND, false, None) + case "^=" => (ZOpcode.XOR, false, None) + case "<<=" => (ZOpcode.SLA, false, Some(RLC, 0xfe)) + case ">>=" => (ZOpcode.SRL, false, Some(RRC, 0x7f)) + case _ => return None + } + shift match { + case Some((nonZ80, mask)) => + Some(env.eval(source) match { + case Some(NumericConstant(n, _)) => + if (n <= 0) Nil else { + if (ctx.options.flag(CompilationFlag.EmitExtended80Opcodes)) { + if (element == ZRegister.MEM_HL && n <= 2) { + List.fill(n.toInt)(ZLine.register(operation, ZRegister.MEM_HL)) + } else { + val builder = mutable.ListBuffer[ZLine]() + builder += ZLine.ld8(ZRegister.A, element) + for (_ <- 0 until n.toInt) { + builder += ZLine.register(operation, ZRegister.A) + } + builder += ZLine.ld8(element, ZRegister.A) + builder.toList + } + } else { + val builder = mutable.ListBuffer[ZLine]() + builder += ZLine.ld8(ZRegister.A, element) + for (_ <- 0 until n.toInt) { + builder += ZLine.register(nonZ80, ZRegister.A) + builder += ZLine.imm8(AND, mask) + } + builder += ZLine.ld8(element, ZRegister.A) + builder.toList + } + } + case _ => return None + }) + case None => + val mod = source match { + case VariableExpression(n) if n == loopVariable => + List(ZLine.register(operation, loopRegister)) + case _ => env.eval(source) match { + case Some(NumericConstant(1, _)) if operator == "+=" && element == ZRegister.MEM_HL => + return Some(List(ZLine.register(INC, ZRegister.MEM_HL))) + case Some(NumericConstant(1, _)) if operator == "-=" && element == ZRegister.MEM_HL => + return Some(List(ZLine.register(DEC, ZRegister.MEM_HL))) + case Some(NumericConstant(2, _)) if operator == "+=" && element == ZRegister.MEM_HL && ctx.options.flag(CompilationFlag.OptimizeForSize) => + return Some(List(ZLine.register(INC, ZRegister.MEM_HL), ZLine.register(INC, ZRegister.MEM_HL))) + case Some(NumericConstant(2, _)) if operator == "-=" && element == ZRegister.MEM_HL && ctx.options.flag(CompilationFlag.OptimizeForSize) => + return Some(List(ZLine.register(DEC, ZRegister.MEM_HL), ZLine.register(DEC, ZRegister.MEM_HL))) + case Some(c) => + if (daa) { + List(ZLine.imm8(operation, c), ZLine.implied(DAA)) + } else { + List(ZLine.imm8(operation, c)) + } + case _ => return None + } + } + Some( + ZLine.ld8(ZRegister.A, element) :: (mod :+ ZLine.ld8(element, ZRegister.A)) + ) + } + } + } + + /** + * + * @param ctx compilation context + * @param target target indexed expression + * @param f original for statement + * @param useDEForTarget use DE instead of HL for target + * @param extraAddressCalculations extra calculations to perform before the loop, before and after POP BC (parameter: is count small) + * @param loadA byte value calculation (parameter: INC_16 or DEC_16) + * @param z80Bulk Z80 opcode for faster operation (parameter: is decreasing) + * @return + */ + def compileMemoryBulk(ctx: CompilationContext, + target: IndexedExpression, + f: ForStatement, + useDEForTarget: Boolean, + preferDecreasing: Boolean, + extraAddressCalculations: Boolean => (List[ZLine], List[ZLine]), + loadA: ZOpcode.Value => List[ZLine], + z80Bulk: Boolean => Option[ZOpcode.Value]): List[ZLine] = { + val one = LiteralExpression(1, 1) + val targetOffset = removeVariableOnce(f.variable, target.index).getOrElse(return compileForStatement(ctx, f)) + if (!targetOffset.isPure) return compileForStatement(ctx, f) + val indexVariableSize = ctx.env.get[Variable](f.variable).typ.size + val wrapper = createForLoopPreconditioningIfStatement(ctx, f) + val decreasingDespiteSyntax = preferDecreasing && (f.direction == ForDirection.ParallelTo || f.direction == ForDirection.ParallelUntil) + val decreasing = f.direction == ForDirection.DownTo || decreasingDespiteSyntax + val plusOne = f.direction == ForDirection.To || f.direction == ForDirection.DownTo || f.direction == ForDirection.ParallelTo + val byteCountExpression = + if (f.direction == ForDirection.DownTo) SumExpression(List(false -> f.start, false -> one, true -> f.end), decimal = false) + else if (plusOne) SumExpression(List(false -> f.end, false -> one, true -> f.start), decimal = false) + else SumExpression(List(false -> f.end, true -> f.start), decimal = false) + val targetIndexExpression = if (decreasingDespiteSyntax) { + SumExpression(List(false -> targetOffset, false -> f.end, true -> one), decimal = false) + } else { + SumExpression(List(false -> targetOffset, false -> f.start), decimal = false) + } + val ldr = z80Bulk(decreasing) + val smallCount = indexVariableSize == 1 && (ldr.isEmpty || !ctx.options.flag(CompilationFlag.EmitZ80Opcodes)) + val calculateByteCount = if (smallCount) { + Z80ExpressionCompiler.compileToA(ctx, byteCountExpression) ++ + List(ZLine.ld8(ZRegister.B, ZRegister.A)) + } else { + Z80ExpressionCompiler.compileToHL(ctx, byteCountExpression) ++ + List(ZLine.ld8(ZRegister.B, ZRegister.H), ZLine.ld8(ZRegister.C, ZRegister.L)) + } + val next = if (decreasing) DEC_16 else INC_16 + val calculateSourceValue = loadA(next) + val calculateTargetAddress = Z80ExpressionCompiler.calculateAddressToHL(ctx, IndexedExpression(target.name, targetIndexExpression)) + val extraInitializationPair = extraAddressCalculations(smallCount) + // TODO: figure the optimal compilation order + val loading = if (useDEForTarget) { + calculateByteCount ++ + Z80ExpressionCompiler.stashBCIfChanged(calculateTargetAddress ++ List(ZLine.ld8(ZRegister.D, ZRegister.H), ZLine.ld8(ZRegister.E, ZRegister.L))) ++ + Z80ExpressionCompiler.stashBCIfChanged(Z80ExpressionCompiler.stashDEIfChanged(extraInitializationPair._1)) ++ + Z80ExpressionCompiler.stashHLIfChanged(Z80ExpressionCompiler.stashDEIfChanged(extraInitializationPair._2)) + } else { + calculateByteCount ++ + Z80ExpressionCompiler.stashBCIfChanged(calculateTargetAddress) ++ + Z80ExpressionCompiler.stashBCIfChanged(Z80ExpressionCompiler.stashHLIfChanged(extraInitializationPair._1)) ++ + Z80ExpressionCompiler.stashHLIfChanged(extraInitializationPair._1) + } + + val label = Z80Compiler.nextLabel("me") + val body = if (ldr.isDefined && ctx.options.flag(CompilationFlag.EmitZ80Opcodes)) { + List(ZLine.implied(ldr.get)) + } else { + ZLine.label(label) :: calculateSourceValue ++ (if (smallCount) { + List( + ZLine.register(next, if (useDEForTarget) ZRegister.DE else ZRegister.HL), + ZLine.djnz(label) + ) + } else { + List( + ZLine.register(next, if (useDEForTarget) ZRegister.DE else ZRegister.HL), + ZLine.register(DEC_16, ZRegister.BC), + ZLine.ld8(ZRegister.A, ZRegister.C), + ZLine.register(OR, ZRegister.B), + ZLine.jump(label, IfFlagSet(ZFlag.Z)) + ) + }) + } + wrapper.flatMap(l => if (l.opcode == NOP) loading ++ body else List(l)) + } + + private def createForLoopPreconditioningIfStatement(ctx: CompilationContext, f: ForStatement): List[ZLine] = { + val operator = f.direction match { + case ForDirection.To | ForDirection.ParallelTo => "<=" + case ForDirection.DownTo => ">=" + case ForDirection.Until | ForDirection.ParallelUntil => "<" + } + Z80StatementCompiler.compile(ctx, IfStatement( + FunctionCallExpression(operator, List(f.start, f.end)), + List(Z80AssemblyStatement(ZOpcode.NOP, NoRegisters, LiteralExpression(0, 1), elidable = false)), + Nil)) + } + + private def removeVariableOnce(variable: String, expr: Expression): Option[Expression] = { + expr match { + case VariableExpression(i) => if (i == variable) Some(LiteralExpression(0, 1)) else None + case SumExpression(exprs, false) => + if (exprs.count(_._2.containsVariable(variable)) == 1) { + Some(SumExpression(exprs.map { + case (false, e) => false -> (if (e.containsVariable(variable)) removeVariableOnce(variable, e).getOrElse(return None) else e) + case (true, e) => if (e.containsVariable(variable)) return None else true -> e + }, decimal = false)) + } else None + case _ => None + } + } + +} diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala index ad509971..f31635bd 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -2,6 +2,7 @@ package millfork.compiler.z80 import millfork.assembly.z80._ import millfork.compiler._ +import millfork.env.NumericConstant import millfork.node.{Expression, ZRegister} /** @@ -12,6 +13,23 @@ object Z80Comparisons { import ComparisonType._ def compile8BitComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { + (ctx.env.eval(l), ctx.env.eval(r)) match { + case (Some(NumericConstant(lc, _)), Some(NumericConstant(rc, _))) => + val constantCondition = compType match { + case Equal => lc == rc + case NotEqual => lc != rc + case GreaterSigned | GreaterUnsigned => lc > rc + case LessOrEqualSigned | LessOrEqualUnsigned => lc <= rc + case GreaterOrEqualSigned | GreaterOrEqualUnsigned=> lc >= rc + case LessSigned | LessUnsigned => lc < rc + } + return branches match { + case BranchIfFalse(b) => if (!constantCondition) List(ZLine.jump(b)) else Nil + case BranchIfTrue(b) => if (constantCondition) List(ZLine.jump(b)) else Nil + case _ => Nil + } + case _ => + } compType match { case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => return compile8BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) @@ -22,6 +40,7 @@ object Z80Comparisons { List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ Z80ExpressionCompiler.stashDEIfChanged(Z80ExpressionCompiler.compileToA(ctx, l)) ++ List(ZLine.register(ZOpcode.CP, ZRegister.E)) + if (branches == NoBranching) return calculateFlags val jump = (compType, branches) match { case (Equal, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) case (Equal, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index e289d900..de8591e1 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -1,9 +1,9 @@ package millfork.compiler.z80 import millfork.assembly.BranchingOpcodeMapping -import millfork.assembly.z80.ZLine +import millfork.assembly.z80._ import millfork.compiler.{AbstractExpressionCompiler, AbstractStatementCompiler, BranchSpec, CompilationContext} -import millfork.env.{BooleanType, ConstantBooleanType, Label, MacroFunction} +import millfork.env._ import millfork.node._ import millfork.assembly.z80.ZOpcode._ import millfork.error.ErrorReporting @@ -74,11 +74,60 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { compileWhileStatement(ctx, s) case s: DoWhileStatement => compileDoWhileStatement(ctx, s) - case f:ForStatement => - compileForStatement(ctx,f) - case s:BreakStatement => + + case f@ForStatement(_, _, _, _, List(Assignment(target: IndexedExpression, source: IndexedExpression))) => + Z80BulkMemoryOperations.compileMemcpy(ctx, target, source, f) + + case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) => + Z80BulkMemoryOperations.compileMemset(ctx, target, source, f) + + case f@ForStatement(variable, _, _, _, List(ExpressionStatement(FunctionCallExpression( + operator@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), + List(target: IndexedExpression, source: Expression) + )))) => + Z80BulkMemoryOperations.compileMemtransform(ctx, target, operator, source, f) + + case f@ForStatement(variable, _, _, _, List( + ExpressionStatement(FunctionCallExpression( + operator1@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), + List(target1: IndexedExpression, source1: Expression) + )), + ExpressionStatement(FunctionCallExpression( + operator2@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), + List(target2: IndexedExpression, source2: Expression) + )) + )) => + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, operator2, source2, f) + + case f@ForStatement(variable, _, _, _, List( + Assignment(target1: IndexedExpression, source1: Expression), + ExpressionStatement(FunctionCallExpression( + operator2@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), + List(target2: IndexedExpression, source2: Expression) + )) + )) => + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, operator2, source2, f) + + case f@ForStatement(variable, _, _, _, List( + ExpressionStatement(FunctionCallExpression( + operator1@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), + List(target1: IndexedExpression, source1: Expression) + )), + Assignment(target2: IndexedExpression, source2: Expression) + )) => + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, "=", source2, f) + + case f@ForStatement(variable, _, _, _, List( + Assignment(target1: IndexedExpression, source1: Expression), + Assignment(target2: IndexedExpression, source2: Expression) + )) => + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, "=", source2, f) + + case f: ForStatement => + compileForStatement(ctx, f) + case s: BreakStatement => compileBreakStatement(ctx, s) - case s:ContinueStatement => + case s: ContinueStatement => compileContinueStatement(ctx, s) case ExpressionStatement(e@FunctionCallExpression(name, params)) => ctx.env.lookupFunction(name, params.map(p => Z80ExpressionCompiler.getExpressionType(ctx, p) -> p)) match { @@ -90,14 +139,22 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { } case ExpressionStatement(e) => Z80ExpressionCompiler.compile(ctx, e, ZExpressionTarget.NOTHING) + case Z80AssemblyStatement(op, reg, expression, elidable) => + val param = ctx.env.evalForAsm(expression) match { + case Some(v) => v + case None => + ErrorReporting.error("Inlining failed due to non-constant things", expression.position) + Constant.Zero + } + List(ZLine(op, reg, param, elidable)) } } def labelChunk(labelName: String) = List(ZLine.label(Label(labelName))) - def jmpChunk(label: Label) = List(ZLine.jump(label)) + def jmpChunk(label: Label) = List(ZLine.jump(label)) - def branchChunk(opcode: BranchingOpcodeMapping, labelName: String) = List(ZLine.jump(Label(labelName), opcode.z80Flags)) + def branchChunk(opcode: BranchingOpcodeMapping, labelName: String) = List(ZLine.jump(Label(labelName), opcode.z80Flags)) def areBlocksLarge(blocks: List[ZLine]*): Boolean = false diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 4a101711..2d118161 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -1,6 +1,7 @@ package millfork.node import millfork.assembly.mos.{AddrMode, Opcode} +import millfork.assembly.z80.{ZOpcode, ZRegisters} import millfork.env.{Constant, ParamPassingConvention} case class Position(filename: String, line: Int, column: Int, cursor: Int) @@ -20,24 +21,34 @@ object Node { sealed trait Expression extends Node { def replaceVariable(variable: String, actualParam: Expression): Expression + def containsVariable(variable: String): Boolean + def isPure: Boolean } case class ConstantArrayElementExpression(constant: Constant) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this + override def containsVariable(variable: String): Boolean = false + override def isPure: Boolean = true } case class LiteralExpression(value: Long, requiredSize: Int) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this + override def containsVariable(variable: String): Boolean = false + override def isPure: Boolean = true } case class BooleanLiteralExpression(value: Boolean) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = this + override def containsVariable(variable: String): Boolean = false + override def isPure: Boolean = true } sealed trait LhsExpression extends Expression case object BlackHoleExpression extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): LhsExpression = this + override def containsVariable(variable: String): Boolean = false + override def isPure: Boolean = true } case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsExpression { @@ -45,11 +56,15 @@ case class SeparateBytesExpression(hi: Expression, lo: Expression) extends LhsEx SeparateBytesExpression( hi.replaceVariable(variable, actualParam), lo.replaceVariable(variable, actualParam)) + override def containsVariable(variable: String): Boolean = hi.containsVariable(variable) || lo.containsVariable(variable) + override def isPure: Boolean = hi.isPure && lo.isPure } case class SumExpression(expressions: List[(Boolean, Expression)], decimal: Boolean) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = SumExpression(expressions.map { case (n, e) => n -> e.replaceVariable(variable, actualParam) }, decimal) + override def containsVariable(variable: String): Boolean = expressions.exists(_._2.containsVariable(variable)) + override def isPure: Boolean = expressions.forall(_._2.isPure) } case class FunctionCallExpression(functionName: String, expressions: List[Expression]) extends Expression { @@ -57,11 +72,15 @@ case class FunctionCallExpression(functionName: String, expressions: List[Expres FunctionCallExpression(functionName, expressions.map { _.replaceVariable(variable, actualParam) }) + override def containsVariable(variable: String): Boolean = expressions.exists(_.containsVariable(variable)) + override def isPure: Boolean = false // TODO } case class HalfWordExpression(expression: Expression, hiByte: Boolean) extends Expression { override def replaceVariable(variable: String, actualParam: Expression): Expression = HalfWordExpression(expression.replaceVariable(variable, actualParam), hiByte) + override def containsVariable(variable: String): Boolean = expression.containsVariable(variable) + override def isPure: Boolean = expression.isPure } sealed class NiceFunctionProperty(override val toString: String) @@ -111,6 +130,8 @@ object ZRegister extends Enumeration { case class VariableExpression(name: String) extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): Expression = if (name == variable) actualParam else this + override def containsVariable(variable: String): Boolean = name == variable + override def isPure: Boolean = true } case class IndexedExpression(name: String, index: Expression) extends LhsExpression { @@ -121,6 +142,8 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress case _ => ??? // TODO } } else IndexedExpression(name, index.replaceVariable(variable, actualParam)) + override def containsVariable(variable: String): Boolean = name == variable || index.containsVariable(variable) + override def isPure: Boolean = index.isPure } sealed trait Statement extends Node { @@ -271,6 +294,10 @@ case class MosAssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value, override def getAllExpressions: List[Expression] = List(expression) } +case class Z80AssemblyStatement(opcode: ZOpcode.Value, registers: ZRegisters, expression: Expression, elidable: Boolean) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = List(expression) +} + case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions) diff --git a/src/main/scala/millfork/output/CompiledMemory.scala b/src/main/scala/millfork/output/CompiledMemory.scala index 93f6a75e..79876da6 100644 --- a/src/main/scala/millfork/output/CompiledMemory.scala +++ b/src/main/scala/millfork/output/CompiledMemory.scala @@ -1,5 +1,7 @@ package millfork.output +import millfork.error.ErrorReporting + import scala.collection.mutable /** @@ -27,4 +29,8 @@ class MemoryBank { val writeable = Array.fill(1 << 16)(false) var start: Int = 0 var end: Int = 0 + + def dump(startAddr: Int, count: Int)(dumper: String => Any): Unit = { + (0 until count).map(i => output(i + startAddr)).grouped(16).zipWithIndex.map { case (c, i) => f"$i%04X: " + c.map(i => f"$i%02x").mkString(" ") }.foreach(dumper) + } } diff --git a/src/test/scala/millfork/test/ForLoopSuite.scala b/src/test/scala/millfork/test/ForLoopSuite.scala index 0238fa75..95d1cb66 100644 --- a/src/test/scala/millfork/test/ForLoopSuite.scala +++ b/src/test/scala/millfork/test/ForLoopSuite.scala @@ -1,6 +1,8 @@ package millfork.test -import millfork.test.emu.{EmuBenchmarkRun, EmuUnoptimizedRun} +import millfork.CpuFamily +import millfork.error.ErrorReporting +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuUnoptimizedRun} import org.scalatest.{FunSuite, Matchers} /** @@ -56,7 +58,7 @@ class ForLoopSuite extends FunSuite with Matchers { } test("For-downto 2") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( """ | array output [55] @$c000 | void main () { @@ -151,4 +153,59 @@ class ForLoopSuite extends FunSuite with Matchers { | void _panic(){while(true){}} """.stripMargin) } + + test("Memcpy") { + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( + """ + | array output[5]@$c001 + | array input = [0,1,4,9,16,25,36,49] + | void main () { + | byte i + | for i,0,until,output.length { + | output[i] = input[i+1] + | } + | } + | void _panic(){while(true){}} + """.stripMargin){ m=> + m.readByte(0xc001) should equal (1) + m.readByte(0xc005) should equal (25) + } + } + + test("Various bulk operations") { + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( + """ + | array output0[5]@$c000 + | array output1[5]@$c010 + | array output2[5]@$c020 + | array input = [0,1,4,9,16,25,36,49] + | void main () { + | byte i + | for i,0,until,5 { + | output0[i] = 0 + | } + | for i,0,paralleluntil,5 { + | output1[i] = 1 + | output2[i] = i + | } + | for i,4,downto,0 { + | output0[i] +'= 4 + | output2[i] <<= 1 + | } + | for i,0,to,4 { + | output1[i] ^= i + | output1[i] += 5 + | } + | } + """.stripMargin){ m=> + m.dump(0xc000, 5)(ErrorReporting.debug(_)) + m.dump(0xc010, 5)(ErrorReporting.debug(_)) + m.dump(0xc020, 5)(ErrorReporting.debug(_)) + m.readByte(0xc001) should equal (4) + m.readByte(0xc023) should equal (6) + m.readByte(0xc013) should equal (7) + } + } + + } diff --git a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala index ff9cb85b..e21b9cc3 100644 --- a/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala +++ b/src/test/scala/millfork/test/emu/EmuBenchmarkRun.scala @@ -41,6 +41,9 @@ object EmuI80BenchmarkRun { object EmuCrossPlatformBenchmarkRun { def apply(platforms: CpuFamily.Value*)(source: String)(verifier: MemoryBank => Unit): Unit = { + if (platforms.isEmpty) { + throw new RuntimeException("Dude, test at least one platform") + } if (platforms.contains(CpuFamily.M6502)) { EmuBenchmarkRun.apply(source)(verifier) } diff --git a/src/test/scala/millfork/test/emu/EmuZ80Run.scala b/src/test/scala/millfork/test/emu/EmuZ80Run.scala index a6227236..d04a38a9 100644 --- a/src/test/scala/millfork/test/emu/EmuZ80Run.scala +++ b/src/test/scala/millfork/test/emu/EmuZ80Run.scala @@ -108,7 +108,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio cpu.resetTStates() while (!cpu.getHalt) { cpu.executeOneInstruction() - dump(cpu) +// dump(cpu) cpu.getTStates should be < TooManyCycles } val tStates = cpu.getTStates