diff --git a/docs/abi/generated-labels.md b/docs/abi/generated-labels.md index 0ef173e8..a13b6ee7 100644 --- a/docs/abi/generated-labels.md +++ b/docs/abi/generated-labels.md @@ -34,6 +34,8 @@ where `11111` is a sequential number and `xx` is the type: * `fi` – end of an `if` statement +* `fe` – body of an `for` statement over a list + * `he` – beginning of the body of a `while` statement * `in` – increment for larger types diff --git a/docs/lang/syntax.md b/docs/lang/syntax.md index a6f3ed73..b70118e7 100644 --- a/docs/lang/syntax.md +++ b/docs/lang/syntax.md @@ -255,7 +255,11 @@ do { Syntax: ``` -for ,,, { +for , , , { +} +for : { +} +for : [ ] { } ``` @@ -277,7 +281,15 @@ for ,,, { * `paralleluntil` – the same as `until`, but the iterations may be executed in any order There is no `paralleldownto`, because it would do the same as `parallelto`. - + +* `` – traverse indices of an array, from 0 to length–1 + +* `` – traverse enum constants of given type, in arbitrary order + +* `` – traverse from 0 to `expression` – 1 + +* `` – traverse every value in the list + ### `break` and `continue` statements Syntax: diff --git a/docs/stdlib/string.md b/docs/stdlib/string.md new file mode 100644 index 00000000..e69de29b diff --git a/src/main/scala/millfork/compiler/AbstractCompiler.scala b/src/main/scala/millfork/compiler/AbstractCompiler.scala index 5aeec494..c8fee29c 100644 --- a/src/main/scala/millfork/compiler/AbstractCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractCompiler.scala @@ -7,4 +7,5 @@ import millfork.assembly.AbstractCode */ abstract class AbstractCompiler[T <: AbstractCode] { def compile(ctx: CompilationContext): List[T] + def packHalves(tuple: (List[T], List[T])): List[T] = tuple._1 ++ tuple._2 } diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 3190976e..1fb4dc4f 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -1,9 +1,8 @@ package millfork.compiler -import millfork.CpuFamily +import millfork.CompilationFlag import millfork.assembly.{AbstractCode, BranchingOpcodeMapping} import millfork.env._ -import millfork.error.ConsoleLogger import millfork.node._ /** @@ -11,11 +10,12 @@ import millfork.node._ */ abstract class AbstractStatementCompiler[T <: AbstractCode] { - def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): List[T] = { - statements.flatMap(s => compile(ctx, s)) + def compile(ctx: CompilationContext, statements: List[ExecutableStatement]): (List[T], List[T]) = { + val chunks = statements.map(s => compile(ctx, s)) + chunks.flatMap(_._1) -> chunks.flatMap(_._2) } - def compile(ctx: CompilationContext, statement: ExecutableStatement): List[T] + def compile(ctx: CompilationContext, statement: ExecutableStatement): (List[T], List[T]) def labelChunk(labelName: String): List[T] @@ -27,18 +27,24 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { def compileExpressionForBranching(ctx: CompilationContext, expr: Expression, branching: BranchSpec): List[T] + def replaceLabel(ctx: CompilationContext, line: T, from: String, to: String): T + + def returnAssemblyStatement: ExecutableStatement + + def callChunk(label: ThingInMemory): List[T] + def areBlocksLarge(blocks: List[T]*): Boolean - def compileWhileStatement(ctx: CompilationContext, s: WhileStatement): List[T] = { + def compileWhileStatement(ctx: CompilationContext, s: WhileStatement): (List[T], List[T]) = { val start = ctx.nextLabel("wh") val middle = ctx.nextLabel("he") val inc = ctx.nextLabel("fp") val end = ctx.nextLabel("ew") val condType = AbstractExpressionCompiler.getExpressionType(ctx, s.condition) - val bodyBlock = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.body) - val incrementBlock = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.increment) + val (bodyBlock, extraBlock) = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.body) + val (incrementBlock, extraBlock2) = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.increment) val largeBodyBlock = areBlocksLarge(bodyBlock, incrementBlock) - condType match { + (condType match { case ConstantBooleanType(_, true) => List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten case ConstantBooleanType(_, false) => Nil @@ -63,18 +69,18 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case _ => ctx.log.error(s"Illegal type for a condition: `$condType`", s.condition.position) Nil - } + }) -> (extraBlock ++ extraBlock2) } - def compileDoWhileStatement(ctx: CompilationContext, s: DoWhileStatement): List[T] = { + def compileDoWhileStatement(ctx: CompilationContext, s: DoWhileStatement): (List[T], List[T]) = { val start = ctx.nextLabel("do") val inc = ctx.nextLabel("fp") val end = ctx.nextLabel("od") val condType = AbstractExpressionCompiler.getExpressionType(ctx, s.condition) - val bodyBlock = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.body) - val incrementBlock = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.increment) + val (bodyBlock, extraBlock) = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.body) + val (incrementBlock, extraBlock2) = compile(ctx.addLabels(s.labels, Label(end), Label(inc)), s.increment) val largeBodyBlock = areBlocksLarge(bodyBlock, incrementBlock) - condType match { + (condType match { case ConstantBooleanType(_, true) => val conditionBlock = compileExpressionForBranching(ctx, s.condition, NoBranching) List(labelChunk(start), bodyBlock, labelChunk(inc), incrementBlock, jmpChunk(start), labelChunk(end)).flatten @@ -98,10 +104,10 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case _ => ctx.log.error(s"Illegal type for a condition: `$condType`", s.condition.position) Nil - } + }) -> (extraBlock ++ extraBlock2) } - def compileForStatement(ctx: CompilationContext, f: ForStatement): List[T] = { + def compileForStatement(ctx: CompilationContext, f: ForStatement): (List[T], List[T]) = { // TODO: check sizes // TODO: special faster cases val p = f.position @@ -133,15 +139,17 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 => val end = ctx.nextLabel("of") - compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) ++ labelChunk(end) + val (main, extra) = compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) + main ++ labelChunk(end) -> extra case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s >= e => - Nil + Nil -> Nil case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e => val end = ctx.nextLabel("of") - compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) ++ labelChunk(end) + val (main, extra) = compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, f.start).pos(p) :: f.body) + main ++ labelChunk(end) -> extra case (ForDirection.To | ForDirection.ParallelTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s > e => - Nil + Nil -> Nil case (ForDirection.Until | ForDirection.ParallelUntil, Some(c), Some(NumericConstant(256, _))) if variable.map(_.typ.size).contains(1) && c.requiredSize == 1 && c.isProvablyNonnegative => @@ -163,9 +171,10 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e => val end = ctx.nextLabel("of") - compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, LiteralExpression(s, ssize)).pos(p) :: f.body) ++ labelChunk(end) + val (main, extra) = compile(ctx.addLabels(names, Label(end), Label(end)), Assignment(vex, LiteralExpression(s, ssize)).pos(p) :: f.body) + main ++ labelChunk(end) -> extra case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s < e => - Nil + Nil -> Nil case (ForDirection.DownTo, Some(NumericConstant(s, 1)), Some(NumericConstant(0, _))) if s > 0 => compile(ctx, List( Assignment( @@ -246,6 +255,131 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { } } + private def tryExtractForEachBodyToNewFunction(variable: String, stmts: List[ExecutableStatement]): (Boolean, List[ExecutableStatement]) = { + def inner2(stmt: ExecutableStatement): Option[ExecutableStatement] = stmt match { + case s: CompoundStatement => s.flatMap(inner2) + case _: BreakStatement => None + case _: ReturnStatement => None + case _: ContinueStatement => None + case s => Some(s) + } + def inner(stmt: ExecutableStatement): Option[ExecutableStatement] = stmt match { + case s: CompoundStatement => if (s.loopVariable == variable) s.flatMap(inner2) else s.flatMap(inner) + case _: BreakStatement => None + case _: ReturnStatement => None + case s@ContinueStatement(l) if l == variable => Some(returnAssemblyStatement.pos(s.position)) + case _: ContinueStatement => None + case s => Some(s) + } + def toplevel(stmt: ExecutableStatement): Option[ExecutableStatement] = stmt match { + case s: IfStatement => s.flatMap(toplevel) + case s: CompoundStatement => s.flatMap(inner) + case _: BreakStatement => None + case _: ReturnStatement => None + case s@ContinueStatement(l) if l == variable => Some(returnAssemblyStatement.pos(s.position)) + case s@ContinueStatement("") => Some(returnAssemblyStatement.pos(s.position)) + case s => Some(s) + } + val list = stmts.map(toplevel) + if (list.forall(_.isDefined)) true -> list.map(_.get) + else false -> stmts + } + + def compileForEachStatement(ctx: CompilationContext, f: ForEachStatement): (List[T], List[T]) = { + val values = f.values match { + case Left(expr) => + expr match { + case VariableExpression(id) => + ctx.env.maybeGet[Thing](id + ".array") match { + case Some(arr:MfArray) => + return compile(ctx, ForStatement( + f.variable, + LiteralExpression(0, 1), + LiteralExpression(arr.sizeInBytes, Constant.minimumSize(arr.sizeInBytes - 1)), + ForDirection.Until, + f.body + )) + case _ => + } + ctx.env.get[Thing](id) match { + case EnumType(_, Some(count)) => + return compile(ctx, ForStatement( + f.variable, + FunctionCallExpression(id, List(LiteralExpression(0, 1))), + FunctionCallExpression(id, List(LiteralExpression(count, 1))), + ForDirection.ParallelUntil, + f.body + )) + case _ => + } + case _ => + } + + return compile(ctx, ForStatement( + f.variable, + LiteralExpression(0, 1), + expr, + ForDirection.Until, + f.body + )) + case Right(vs) => vs + } + val endLabel = ctx.nextLabel("fe") + val continueLabelPlaceholder = ctx.nextLabel("fe") + val (inlinedBody, extra) = compile(ctx.addLabels(Set("", f.variable), Label(endLabel), Label(continueLabelPlaceholder)), f.body) + values.size match { + case 0 => Nil -> Nil + case 1 => + val tuple = compile(ctx, + Assignment( + VariableExpression(f.variable).pos(f.position), + values.head + ).pos(f.position) + ) + tuple._1 ++ inlinedBody -> tuple._2 + case valueCount => + val (extractable, extracted) = tryExtractForEachBodyToNewFunction(f.variable, f.body) + val (extractedBody, extra2) = compile(ctx.addStack(2), extracted :+ returnAssemblyStatement) + val inlinedBodySize = inlinedBody.map(_.sizeInBytes).sum + val extractedBodySize = extractedBody.map(_.sizeInBytes).sum + val sizeIfInlined = inlinedBodySize * valueCount + val sizeIfExtracted = extractedBodySize + 3 * valueCount + val expectedOptimizationPotentialFromInlining = valueCount * 2 + val shouldExtract = true + if (ctx.options.flag(CompilationFlag.OptimizeForSonicSpeed)) false + else sizeIfInlined - expectedOptimizationPotentialFromInlining > sizeIfExtracted + if (shouldExtract) { + if (extractable) { + val callLabel = ctx.nextLabel("fe") + val calls = values.flatMap(expr => compile(ctx, + Assignment( + VariableExpression(f.variable).pos(f.position), + expr + ) + )._1 ++ callChunk(Label(callLabel))) + return calls -> (labelChunk(callLabel) ++ extractedBody ++ extra ++ extra2) + } else { + ctx.log.warn("For loop too complex to extract, inlining", f.position) + } + } + + val inlinedEverything = values.flatMap { expr => + val tuple = compile(ctx, + Assignment( + VariableExpression(f.variable).pos(f.position), + expr + ) + ) + if (tuple._2.nonEmpty) ??? + val compiled = tuple._1 ++ inlinedBody + val continueLabel = ctx.nextLabel("fe") + compiled.map(replaceLabel(ctx, _, continueLabelPlaceholder, continueLabel)) ++ labelChunk(continueLabel) + } ++ labelChunk(endLabel) + + inlinedEverything -> extra + } + } + def compileBreakStatement(ctx: CompilationContext, s: BreakStatement) :List[T] = { ctx.breakLabels.get(s.label) match { case None => @@ -268,13 +402,13 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { } } - def compileIfStatement(ctx: CompilationContext, s: IfStatement): List[T] = { + def compileIfStatement(ctx: CompilationContext, s: IfStatement): (List[T], List[T]) = { val condType = AbstractExpressionCompiler.getExpressionType(ctx, s.condition) - val thenBlock = compile(ctx, s.thenBranch) - val elseBlock = compile(ctx, s.elseBranch) + val (thenBlock, extra1) = compile(ctx, s.thenBranch) + val (elseBlock, extra2) = compile(ctx, s.elseBranch) val largeThenBlock = areBlocksLarge(thenBlock) val largeElseBlock = areBlocksLarge(elseBlock) - condType match { + val mainCode: List[T] = condType match { case ConstantBooleanType(_, true) => compileExpressionForBranching(ctx, s.condition, NoBranching) ++ thenBlock case ConstantBooleanType(_, false) => @@ -373,6 +507,6 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { ctx.log.error(s"Illegal type for a condition: `$condType`", s.condition.position) Nil } - + mainCode -> (extra1 ++ extra2) } } diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index 1224e7ff..4373b6ad 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -115,6 +115,11 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements val (b, _) = optimizeStmts(body, Map()) val (i, _) = optimizeStmts(inc, Map()) DoWhileStatement(b, i, c, labels).pos(pos) -> Map() + case f@ForEachStatement(v, arr, body) => + for (a <- arr.right.getOrElse(Nil)) cv = search(a, cv) + val a = arr.map(_.map(optimizeExpr(_, Map()))) + val (b, _) = optimizeStmts(body, Map()) + ForEachStatement(v, a, b).pos(pos) -> Map() case f@ForStatement(v, st, en, dir, body) => maybeOptimizeForStatement(f) match { case Some(x) => x diff --git a/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala b/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala index 25faa178..0d831992 100644 --- a/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala +++ b/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala @@ -17,7 +17,7 @@ object MosBulkMemoryOperations { target.name != f.variable || target.index.containsVariable(f.variable) || !target.index.isPure || - f.direction == ForDirection.DownTo) return MosStatementCompiler.compileForStatement(ctx, f) + f.direction == ForDirection.DownTo) return MosStatementCompiler.compileForStatement(ctx, f)._1 ctx.env.getPointy(target.name) val sizeExpr = f.direction match { case ForDirection.DownTo => @@ -31,7 +31,7 @@ object MosBulkMemoryOperations { val w = ctx.env.get[Type]("word") val size = ctx.env.eval(sizeExpr) match { case Some(c) => c.quickSimplify - case _ => return MosStatementCompiler.compileForStatement(ctx, f) + case _ => return MosStatementCompiler.compileForStatement(ctx, f)._1 } val useTwoRegs = ctx.options.flag(CompilationFlag.OptimizeForSpeed) && ctx.options.zpRegisterSize >= 4 val loadReg = MosExpressionCompiler.compile(ctx, SumExpression(List(false -> f.start, false -> target.index), decimal = false), Some(w -> reg), BranchSpec.None) ++ ( diff --git a/src/main/scala/millfork/compiler/mos/MosCompiler.scala b/src/main/scala/millfork/compiler/mos/MosCompiler.scala index 8eaea63e..ed54e245 100644 --- a/src/main/scala/millfork/compiler/mos/MosCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosCompiler.scala @@ -16,7 +16,7 @@ object MosCompiler extends AbstractCompiler[AssemblyLine] { override def compile(ctx: CompilationContext): List[AssemblyLine] = { ctx.env.nameCheck(ctx.function.code) - val chunk = MosStatementCompiler.compile(ctx, new MosStatementPreprocessor(ctx, ctx.function.code)()) + val chunk = packHalves(MosStatementCompiler.compile(ctx, new MosStatementPreprocessor(ctx, ctx.function.code)())) val zpRegisterSize = ctx.options.zpRegisterSize val storeParamsFromRegisters = (ctx.function.params match { diff --git a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala index 5b211552..02826758 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala @@ -28,7 +28,16 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { MosExpressionCompiler.compile(ctx, expr, Some(b, RegisterVariable(MosRegister.A, b)), branching) } - def compile(ctx: CompilationContext, statement: ExecutableStatement): List[AssemblyLine] = { + override def replaceLabel(ctx: CompilationContext, line: AssemblyLine, from: String, to: String): AssemblyLine = line.parameter match { + case MemoryAddressConstant(Label(l)) if l == from => line.copy(parameter = MemoryAddressConstant(Label(to))) + case _ => line + } + + override def returnAssemblyStatement: ExecutableStatement = MosAssemblyStatement(RTS, AddrMode.Implied, LiteralExpression(0,1), Elidability.Elidable) + + override def callChunk(label: ThingInMemory): List[AssemblyLine] = List(AssemblyLine.absolute(JSR, label.toAddress)) + + def compile(ctx: CompilationContext, statement: ExecutableStatement): (List[AssemblyLine], List[AssemblyLine]) = { val env = ctx.env val m = ctx.function val b = env.get[Type]("byte") @@ -142,10 +151,10 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { List(AssemblyLine.implied(RTS)) }) } - (statement match { + val code: (List[AssemblyLine], List[AssemblyLine]) = statement match { case EmptyStatement(stmts) => stmts.foreach(s => compile(ctx, s)) - Nil + Nil -> Nil case MosAssemblyStatement(o, a, x, e) => val c: Constant = x match { // TODO: hmmm @@ -165,7 +174,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { case Indirect if o != JMP => IndexedZ case _ => a } - List(AssemblyLine(o, actualAddrMode, c, e)) + List(AssemblyLine(o, actualAddrMode, c, e)) -> Nil case RawBytesStatement(contents) => env.extractArrayContents(contents).map { expr => env.eval(expr) match { @@ -174,16 +183,16 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { ctx.log.error("Non-constant raw byte", position = statement.position) AssemblyLine(BYTE, RawByte, Constant.Zero, elidability = Elidability.Fixed) } - } + } -> Nil case Assignment(dest, source) => - MosExpressionCompiler.compileAssignment(ctx, source, dest) + MosExpressionCompiler.compileAssignment(ctx, source, dest) -> Nil case ExpressionStatement(e@FunctionCallExpression(name, params)) => env.lookupFunction(name, params.map(p => MosExpressionCompiler.getExpressionType(ctx, p) -> p)) match { case Some(i: MacroFunction) => val (paramPreparation, inlinedStatements) = MosMacroExpander.inlineFunction(ctx, i, params, e.position) - paramPreparation ++ compile(ctx.withInlinedEnv(i.environment, ctx.nextLabel("en")), inlinedStatements) + paramPreparation ++ compile(ctx.withInlinedEnv(i.environment, ctx.nextLabel("en")), inlinedStatements)._1 -> Nil case _ => - MosExpressionCompiler.compile(ctx, e, None, NoBranching) + MosExpressionCompiler.compile(ctx, e, None, NoBranching) -> Nil } case ExpressionStatement(e) => e match { @@ -191,11 +200,11 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { ctx.log.warn("Pointless expression statement", statement.position) case _ => } - MosExpressionCompiler.compile(ctx, e, None, NoBranching) + MosExpressionCompiler.compile(ctx, e, None, NoBranching) -> Nil case ReturnStatement(None) => // TODO: return type check // TODO: better stackpointer fix - ctx.function.returnType match { + (ctx.function.returnType match { case _: BooleanType => stackPointerFixBeforeReturn(ctx) ++ returnInstructions case t => t.size match { @@ -221,11 +230,11 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions } - } + }) -> Nil case s : ReturnDispatchStatement => - MosReturnDispatch.compile(ctx, s) + MosReturnDispatch.compile(ctx, s) -> Nil case ReturnStatement(Some(e)) => - m.returnType match { + (m.returnType match { case _: BooleanType => m.returnType.size match { case 0 => @@ -262,7 +271,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { MosExpressionCompiler.compileAssignment(ctx, e, VariableExpression(ctx.function.name + ".return")) ++ stackPointerFixBeforeReturn(ctx) ++ List(AssemblyLine.discardAF(), AssemblyLine.discardXF(), AssemblyLine.discardYF()) ++ returnInstructions } - } + }) -> Nil case s: IfStatement => compileIfStatement(ctx, s) case s: WhileStatement => @@ -270,14 +279,17 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { case s: DoWhileStatement => compileDoWhileStatement(ctx, s) case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) => - MosBulkMemoryOperations.compileMemset(ctx, target, source, f) + MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil case f:ForStatement => compileForStatement(ctx,f) + case f:ForEachStatement => + compileForEachStatement(ctx, f) case s:BreakStatement => - compileBreakStatement(ctx, s) + compileBreakStatement(ctx, s) -> Nil case s:ContinueStatement => - compileContinueStatement(ctx, s) - }).map(_.positionIfEmpty(statement.position)) + compileContinueStatement(ctx, s) -> Nil + } + code._1.map(_.positionIfEmpty(statement.position)) -> code._2.map(_.positionIfEmpty(statement.position)) } private def stackPointerFixBeforeReturn(ctx: CompilationContext, preserveA: Boolean = false, preserveX: Boolean = false, preserveY: Boolean = false): List[AssemblyLine] = { diff --git a/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala index 78b9c030..ede40d8f 100644 --- a/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala +++ b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala @@ -20,8 +20,8 @@ object Z80BulkMemoryOperations { * Compiles loops like for i,a,until,b { p[i] = q[i] } */ def compileMemcpy(ctx: CompilationContext, target: IndexedExpression, source: IndexedExpression, f: ForStatement): List[ZLine] = { - val sourceOffset = removeVariableOnce(f.variable, source.index).getOrElse(return compileForStatement(ctx, f)) - if (!sourceOffset.isPure) return compileForStatement(ctx, f) + val sourceOffset = removeVariableOnce(f.variable, source.index).getOrElse(return compileForStatement(ctx, f)._1) + if (!sourceOffset.isPure) return compileForStatement(ctx, f)._1 val sourceIndexExpression = SumExpression(List(false -> sourceOffset, false -> f.start), decimal = false) val calculateSource = Z80ExpressionCompiler.calculateAddressToHL(ctx, IndexedExpression(source.name, sourceIndexExpression)) compileMemoryBulk(ctx, target, f, @@ -106,7 +106,7 @@ object Z80BulkMemoryOperations { */ def compileMemtransform(ctx: CompilationContext, target: IndexedExpression, operator: String, source: Expression, f: ForStatement): List[ZLine] = { val c = determineExtraLoopRegister(ctx, f, source.containsVariable(f.variable)) - val load = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator, source, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator, source, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) import scala.util.control.Breaks._ breakable{ return compileMemoryBulk(ctx, target, f, @@ -117,7 +117,7 @@ object Z80BulkMemoryOperations { _ => None ) } - compileForStatement(ctx, f) + compileForStatement(ctx, f)._1 } /** @@ -131,8 +131,8 @@ object Z80BulkMemoryOperations { f: ForStatement): List[ZLine] = { import scala.util.control.Breaks._ val c = determineExtraLoopRegister(ctx, f, source1.containsVariable(f.variable) || source2.containsVariable(f.variable)) - val target1Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)) - val target2Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)) + val target1Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)._1) + val target2Offset = removeVariableOnce(f.variable, target2.index).getOrElse(return compileForStatement(ctx, f)._1) val target1IndexExpression = if (c.countDownDespiteSyntax) { SumExpression(List(false -> target1Offset, false -> f.end, true -> LiteralExpression(1, 1)), decimal = false) } else { @@ -148,8 +148,8 @@ object Z80BulkMemoryOperations { case _ => false }) if (fused) { - val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) - val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) val loads = load1 ++ load2 breakable{ return compileMemoryBulk(ctx, target1, f, @@ -164,12 +164,12 @@ object Z80BulkMemoryOperations { val goodness1 = goodnessForHL(ctx, operator1, source1) val goodness2 = goodnessForHL(ctx, operator2, source2) val loads = if (goodness1 <= goodness2) { - val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) - val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) load1 ++ load2 } else { - val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) - val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)) + val load1 = buildMemtransformLoader(ctx, ZRegister.MEM_HL, f.variable, operator1, source1, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) + val load2 = buildMemtransformLoader(ctx, ZRegister.MEM_DE, f.variable, operator2, source2, c.loopRegister).getOrElse(return compileForStatement(ctx, f)._1) load1 ++ load2 } val targetForDE = if (goodness1 <= goodness2) target1 else target2 @@ -187,7 +187,7 @@ object Z80BulkMemoryOperations { ) } } - compileForStatement(ctx, f) + compileForStatement(ctx, f)._1 } private case class ExtraLoopRegister(loopRegister: ZRegister.Value, initC: List[ZLine], nextC: List[ZLine], countDownDespiteSyntax: Boolean) @@ -398,8 +398,8 @@ object Z80BulkMemoryOperations { loadA: ZOpcode.Value => List[ZLine], z80Bulk: Boolean => Option[ZOpcode.Value]): List[ZLine] = { val one = LiteralExpression(1, 1) - val targetOffset = removeVariableOnce(f.variable, target.index).getOrElse(return compileForStatement(ctx, f)) - if (!targetOffset.isPure) return compileForStatement(ctx, f) + val targetOffset = removeVariableOnce(f.variable, target.index).getOrElse(return compileForStatement(ctx, f)._1) + if (!targetOffset.isPure) return compileForStatement(ctx, f)._1 val indexVariableSize = ctx.env.get[Variable](f.variable).typ.size val wrapper = createForLoopPreconditioningIfStatement(ctx, f) val decreasingDespiteSyntax = preferDecreasing && (f.direction == ForDirection.ParallelTo || f.direction == ForDirection.ParallelUntil) @@ -467,7 +467,7 @@ object Z80BulkMemoryOperations { Z80StatementCompiler.compile(ctx, IfStatement( FunctionCallExpression(operator, List(f.start, f.end)), List(Z80AssemblyStatement(ZOpcode.NOP, NoRegisters, None, LiteralExpression(0, 1), elidability = Elidability.Fixed)), - Nil)) + Nil))._1 } private def removeVariableOnce(variable: String, expr: Expression): Option[Expression] = { diff --git a/src/main/scala/millfork/compiler/z80/Z80Compiler.scala b/src/main/scala/millfork/compiler/z80/Z80Compiler.scala index f1f7f35d..0c56a62f 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Compiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Compiler.scala @@ -14,7 +14,7 @@ object Z80Compiler extends AbstractCompiler[ZLine] { override def compile(ctx: CompilationContext): List[ZLine] = { ctx.env.nameCheck(ctx.function.code) - val chunk = Z80StatementCompiler.compile(ctx, new Z80StatementPreprocessor(ctx, ctx.function.code)()) + val chunk = packHalves(Z80StatementCompiler.compile(ctx, new Z80StatementPreprocessor(ctx, ctx.function.code)())) val label = ZLine.label(Label(ctx.function.name)).copy(elidability = Elidability.Fixed) val storeParamsFromRegisters = ctx.function.params match { case NormalParamSignature(List(param)) if param.typ.size == 1 => diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index 7eb20604..4e37efe0 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -1,7 +1,7 @@ package millfork.compiler.z80 import millfork.CompilationFlag -import millfork.assembly.BranchingOpcodeMapping +import millfork.assembly.{BranchingOpcodeMapping, Elidability} import millfork.assembly.z80._ import millfork.compiler._ import millfork.env._ @@ -15,14 +15,14 @@ import millfork.error.ConsoleLogger object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { - def compile(ctx: CompilationContext, statement: ExecutableStatement): List[ZLine] = { + def compile(ctx: CompilationContext, statement: ExecutableStatement): (List[ZLine], List[ZLine])= { val options = ctx.options val env = ctx.env val ret = Z80Compiler.restoreRegistersAndReturn(ctx) - (statement match { + val code: (List[ZLine], List[ZLine]) = statement match { case EmptyStatement(stmts) => stmts.foreach(s => compile(ctx, s)) - Nil + Nil -> Nil case ReturnStatement(None) => fixStackOnReturn(ctx) ++ (ctx.function.returnType match { case _: BooleanType => @@ -34,9 +34,9 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { ctx.log.warn("Returning without a value", statement.position) List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE)) ++ ret } - }) + }) -> Nil case ReturnStatement(Some(e)) => - ctx.function.returnType match { + (ctx.function.returnType match { case t: BooleanType => t.size match { case 0 => ctx.log.error("Cannot return anything from a void function", statement.position) @@ -76,15 +76,15 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { List(ZLine.implied(DISCARD_F), ZLine.implied(DISCARD_A), ZLine.implied(DISCARD_HL), ZLine.implied(DISCARD_BC), ZLine.implied(DISCARD_DE), ZLine.implied(RET)) } - } + }) -> Nil case Assignment(destination, source) => val sourceType = AbstractExpressionCompiler.getExpressionType(ctx, source) - sourceType.size match { + (sourceType.size match { case 0 => ??? case 1 => Z80ExpressionCompiler.compileToA(ctx, source) ++ Z80ExpressionCompiler.storeA(ctx, destination, sourceType.isSigned) case 2 => Z80ExpressionCompiler.compileToHL(ctx, source) ++ Z80ExpressionCompiler.storeHL(ctx, destination, sourceType.isSigned) case s => Z80ExpressionCompiler.storeLarge(ctx, destination, source) - } + }) -> Nil case s: IfStatement => compileIfStatement(ctx, s) case s: WhileStatement => @@ -92,19 +92,19 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { case s: DoWhileStatement => compileDoWhileStatement(ctx, s) case s: ReturnDispatchStatement => - Z80ReturnDispatch.compile(ctx, s) + Z80ReturnDispatch.compile(ctx, s) -> Nil case f@ForStatement(_, _, _, _, List(Assignment(target: IndexedExpression, source: IndexedExpression))) => - Z80BulkMemoryOperations.compileMemcpy(ctx, target, source, f) + Z80BulkMemoryOperations.compileMemcpy(ctx, target, source, f) -> Nil case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !source.containsVariable(variable) => - Z80BulkMemoryOperations.compileMemset(ctx, target, source, f) + Z80BulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil case f@ForStatement(variable, _, _, _, List(ExpressionStatement(FunctionCallExpression( operator@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), List(target: IndexedExpression, source: Expression) )))) => - Z80BulkMemoryOperations.compileMemtransform(ctx, target, operator, source, f) + Z80BulkMemoryOperations.compileMemtransform(ctx, target, operator, source, f) -> Nil case f@ForStatement(variable, _, _, _, List( ExpressionStatement(FunctionCallExpression( @@ -116,7 +116,7 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { List(target2: IndexedExpression, source2: Expression) )) )) => - Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, operator2, source2, f) + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, operator2, source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( Assignment(target1: IndexedExpression, source1: Expression), @@ -125,7 +125,7 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { List(target2: IndexedExpression, source2: Expression) )) )) => - Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, operator2, source2, f) + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, operator2, source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( ExpressionStatement(FunctionCallExpression( @@ -134,30 +134,33 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { )), Assignment(target2: IndexedExpression, source2: Expression) )) => - Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, "=", source2, f) + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, "=", source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( Assignment(target1: IndexedExpression, source1: Expression), Assignment(target2: IndexedExpression, source2: Expression) )) => - Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, "=", source2, f) + Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, "=", source2, f) -> Nil case f: ForStatement => compileForStatement(ctx, f) + case f:ForEachStatement => + compileForEachStatement(ctx, f) case s: BreakStatement => - compileBreakStatement(ctx, s) + compileBreakStatement(ctx, s) -> Nil case s: ContinueStatement => - compileContinueStatement(ctx, s) + compileContinueStatement(ctx, s) -> Nil case ExpressionStatement(e@FunctionCallExpression(name, params)) => env.lookupFunction(name, params.map(p => Z80ExpressionCompiler.getExpressionType(ctx, p) -> p)) match { case Some(i: MacroFunction) => val (paramPreparation, inlinedStatements) = Z80MacroExpander.inlineFunction(ctx, i, params, e.position) - paramPreparation ++ compile(ctx.withInlinedEnv(i.environment, ctx.nextLabel("en")), inlinedStatements) + val (main, extra) = compile(ctx.withInlinedEnv(i.environment, ctx.nextLabel("en")), inlinedStatements) + paramPreparation ++ main -> extra case _ => - Z80ExpressionCompiler.compile(ctx, e, ZExpressionTarget.NOTHING) + Z80ExpressionCompiler.compile(ctx, e, ZExpressionTarget.NOTHING) -> Nil } case ExpressionStatement(e) => - Z80ExpressionCompiler.compile(ctx, e, ZExpressionTarget.NOTHING) + Z80ExpressionCompiler.compile(ctx, e, ZExpressionTarget.NOTHING) -> Nil case Z80AssemblyStatement(op, reg, offset, expression, elidability) => val param: Constant = expression match { // TODO: hmmm @@ -191,8 +194,9 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { } case _ => reg } - List(ZLine(op, registers, param, elidability)) - }).map(_.positionIfEmpty(statement.position)) + List(ZLine(op, registers, param, elidability)) -> Nil + } + code._1.map(_.positionIfEmpty(statement.position)) -> code._2.map(_.positionIfEmpty(statement.position)) } private def fixStackOnReturn(ctx: CompilationContext): List[ZLine] = { @@ -265,4 +269,13 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { override def compileExpressionForBranching(ctx: CompilationContext, expr: Expression, branching: BranchSpec): List[ZLine] = Z80ExpressionCompiler.compile(ctx, expr, ZExpressionTarget.NOTHING, branching) + + override def replaceLabel(ctx: CompilationContext, line: ZLine, from: String, to: String): ZLine = line.parameter match { + case MemoryAddressConstant(Label(l)) if l == from => line.copy(parameter = MemoryAddressConstant(Label(to))) + case _ => line + } + + override def returnAssemblyStatement: ExecutableStatement = Z80AssemblyStatement(RET, NoRegisters, None, LiteralExpression(0,1), Elidability.Elidable) + + override def callChunk(label: ThingInMemory): List[ZLine] = List(ZLine(CALL, NoRegisters, label.toAddress)) } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index e844188e..07648192 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -224,6 +224,8 @@ trait MfArray extends ThingInMemory with IndexableThing { def indexType: VariableType def elementType: VariableType override def isVolatile: Boolean = false + /* TODO: what if larger elements? */ + def sizeInBytes: Int } case class UninitializedArray(name: String, /* TODO: what if larger elements? */ sizeInBytes: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val alignment: MemoryAlignment) extends MfArray with UninitializedMemory { @@ -256,6 +258,8 @@ case class InitializedArray(name: String, address: Option[Constant], contents: L override def bank(compilationOptions: CompilationOptions): String = declaredBank.getOrElse(compilationOptions.platform.defaultCodeBank) override def zeropage: Boolean = false + + override def sizeInBytes: Int = contents.size } case class RelativeVariable(name: String, address: Constant, typ: Type, zeropage: Boolean, declaredBank: Option[String], override val isVolatile: Boolean) extends VariableInMemory { diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index a0c2f53e..295f26fa 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -313,6 +313,10 @@ case class RawBytesStatement(contents: ArrayContents) extends ExecutableStatemen sealed trait CompoundStatement extends ExecutableStatement { def getChildStatements: Seq[Statement] + + def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] + + def loopVariable: String } case class ExpressionStatement(expression: Expression) extends ExecutableStatement { @@ -363,12 +367,30 @@ case class IfStatement(condition: Expression, thenBranch: List[ExecutableStateme override def getAllExpressions: List[Expression] = condition :: (thenBranch ++ elseBranch).flatMap(_.getAllExpressions) override def getChildStatements: Seq[Statement] = thenBranch ++ elseBranch + + override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { + val t = thenBranch.map(f) + val e = elseBranch.map(f) + if (t.forall(_.isDefined) && e.forall(_.isDefined)) Some(IfStatement(condition, t.map(_.get), e.map(_.get)).pos(this.position)) + else None + } + + override def loopVariable: String = "-none-" } case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) override def getChildStatements: Seq[Statement] = body ++ increment + + override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { + val b = body.map(f) + val i = increment.map(f) + if (b.forall(_.isDefined) && i.forall(_.isDefined)) Some(WhileStatement(condition, b.map(_.get), i.map(_.get), labels).pos(this.position)) + else None + } + + override def loopVariable: String = "-none-" } object ForDirection extends Enumeration { @@ -379,12 +401,42 @@ case class ForStatement(variable: String, start: Expression, end: Expression, di override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: body.flatMap(_.getAllExpressions) override def getChildStatements: Seq[Statement] = body + + override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { + val b = body.map(f) + if (b.forall(_.isDefined)) Some(ForStatement(variable, start, end, direction, b.map(_.get)).pos(this.position)) + else None + } + + override def loopVariable: String = variable +} + +case class ForEachStatement(variable: String, values: Either[Expression, List[Expression]], body: List[ExecutableStatement]) extends CompoundStatement { + override def getAllExpressions: List[Expression] = VariableExpression(variable) :: (values.fold(List(_), identity) ++ body.flatMap(_.getAllExpressions)) + + override def getChildStatements: Seq[Statement] = body + override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { + val b = body.map(f) + if (b.forall(_.isDefined)) Some(ForEachStatement(variable,values, b.map(_.get)).pos(this.position)) + else None + } + + override def loopVariable: String = variable } case class DoWhileStatement(body: List[ExecutableStatement], increment: List[ExecutableStatement], condition: Expression, labels: Set[String] = Set("", "do")) extends CompoundStatement { override def getAllExpressions: List[Expression] = condition :: body.flatMap(_.getAllExpressions) override def getChildStatements: Seq[Statement] = body ++ increment + + override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { + val b = body.map(f) + val i = increment.map(f) + if (b.forall(_.isDefined) && i.forall(_.isDefined)) Some(DoWhileStatement(b.map(_.get), i.map(_.get), condition, labels).pos(this.position)) + else None + } + + override def loopVariable: String = "-none-" } case class BreakStatement(label: String) extends ExecutableStatement { diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 9218b9d0..acd7a38a 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -331,6 +331,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri ifStatement | whileStatement | forStatement | + forEachStatement | doWhileStatement | breakStatement | continueStatement | @@ -390,13 +391,19 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } yield Seq(WhileStatement(condition, body.toList, Nil)) def forStatement: P[Seq[ExecutableStatement]] = for { - identifier <- "for" ~ SWS ~/ identifier ~/ HWS ~ "," ~/ HWS ~ Pass + identifier <- "for" ~ SWS ~ identifier ~ HWS ~ "," ~/ HWS ~ Pass start <- mfExpression(nonStatementLevel, false) ~ HWS ~ "," ~/ HWS ~/ Pass direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass end <- mfExpression(nonStatementLevel, false) body <- AWS ~ executableStatements } yield Seq(ForStatement(identifier, start, end, direction, body.toList)) + def forEachStatement: P[Seq[ExecutableStatement]] = for { + id <- "for" ~ SWS ~/ identifier ~/ HWS ~ ":" ~/ HWS ~ Pass + values <- ("[" ~/ AWS ~/ mfExpression(0, false).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ "]" ~/ "").map(seq => Right(seq.toList)) | mfExpression(0, false).map(Left(_)) + body <- AWS ~ executableStatements + } yield Seq(ForEachStatement(id, values, body.toList)) + def inlineAssembly: P[Seq[ExecutableStatement]] = "asm" ~ !letterOrDigit ~/ AWS ~ asmStatements //noinspection MutatorLikeMethodIsParameterless diff --git a/src/test/scala/millfork/test/ForLoopSuite.scala b/src/test/scala/millfork/test/ForLoopSuite.scala index 138fd20d..1f2c424b 100644 --- a/src/test/scala/millfork/test/ForLoopSuite.scala +++ b/src/test/scala/millfork/test/ForLoopSuite.scala @@ -114,6 +114,21 @@ class ForLoopSuite extends FunSuite with Matchers { | } """.stripMargin)(_.readByte(0xc000) should equal(15)) } + + test("For-until 2") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | word output @$c000 + | void main () { + | byte i + | output = 0 + | for i : 6 { + | output += i + | } + | } + """.stripMargin)(_.readByte(0xc000) should equal(15)) + } + test("For-parallelto") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( """ @@ -127,6 +142,7 @@ class ForLoopSuite extends FunSuite with Matchers { | } """.stripMargin)(_.readByte(0xc000) should equal(15)) } + test("For-paralleluntil") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( """ @@ -198,6 +214,24 @@ class ForLoopSuite extends FunSuite with Matchers { } } + test("Memcpy 2") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | array output[5]@$c001 + | array input = [0,1,4,9,16,25,36,49] + | void main () { + | byte i + | for i : output { + | output[i] = input[i+1] + | } + | } + | void _panic(){while(true){}} + """.stripMargin){ m=> + m.readByte(0xc001) should equal (1) + m.readByte(0xc005) should equal (25) + } + } + test("Memset with index") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( """ @@ -215,6 +249,23 @@ class ForLoopSuite extends FunSuite with Matchers { } } + test("Memset with index 2") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | array output[5]@$c001 + | void main () { + | byte i + | for i : output { + | output[i] = 22 + | } + | } + | void _panic(){while(true){}} + """.stripMargin){ m=> + m.readByte(0xc001) should equal (22) + m.readByte(0xc005) should equal (22) + } + } + test("Memset with pointer") { EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( """ @@ -339,4 +390,30 @@ class ForLoopSuite extends FunSuite with Matchers { """.stripMargin) } + + test("For each") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | array output[$400]@$c000 + | void main () { + | pointer p + | for p:[$c000, $c003, $c005, $c007]{ + | p[0] = 34 + | } + | for p:[$c001, $c004, $c006, $c008]{ + | p[0] = 42 + | } + | } + | void _panic(){while(true){}} + """.stripMargin) { m => + m.readByte(0xc000) should equal(34) + m.readByte(0xc003) should equal(34) + m.readByte(0xc005) should equal(34) + m.readByte(0xc007) should equal(34) + m.readByte(0xc001) should equal(42) + m.readByte(0xc004) should equal(42) + m.readByte(0xc006) should equal(42) + m.readByte(0xc008) should equal(42) + } + } }