From d5367cc1fe78d6b8611dc0710bd2094b823b43ea Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 24 Jul 2020 19:11:27 +0200 Subject: [PATCH] `for` loops over arrays --- docs/lang/syntax.md | 9 ++ .../compiler/AbstractStatementCompiler.scala | 86 ++++++++++++++++--- .../AbstractStatementPreprocessor.scala | 12 ++- .../millfork/compiler/MacroExpander.scala | 4 +- .../compiler/mos/MosStatementCompiler.scala | 6 +- .../mos/MosStatementPreprocessor.scala | 2 +- .../compiler/z80/Z80StatementCompiler.scala | 14 +-- .../z80/Z80StatementPreprocessor.scala | 2 +- src/main/scala/millfork/env/Environment.scala | 5 ++ src/main/scala/millfork/node/Node.scala | 13 +-- src/main/scala/millfork/parser/MfParser.scala | 8 +- .../scala/millfork/test/ForLoopSuite.scala | 85 ++++++++++++++++++ 12 files changed, 207 insertions(+), 39 deletions(-) diff --git a/docs/lang/syntax.md b/docs/lang/syntax.md index 1c1a6a3d..399c63f3 100644 --- a/docs/lang/syntax.md +++ b/docs/lang/syntax.md @@ -371,12 +371,18 @@ for , , , { } for : { } +for : { +} +for , : { +} for : [ ] { } ``` * `` – an already defined numeric variable +* `` – an already defined numeric variable + * `` – the type of range to traverse: * `to` – from `` inclusive to `` inclusive, in ascending order @@ -411,6 +417,9 @@ for : [ ] { * `` – traverse enum constants of given type, in arbitrary order +* `` – traverse array contents, in arbitrary order, +assigning the index to `` and either the element or the pointer to the element to `` + * `` – traverse every value in the list, in the given order. Values do not have to be constant. If a value is not a constant and its value changes while executing the loop, the behaviour is undefined. diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 0b0950b8..90fa0136 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -112,6 +112,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { def compileForStatement(ctx: CompilationContext, f: ForStatement): (List[T], List[T]) = { // TODO: check sizes // TODO: special faster cases + val extraIncrement = f.extraIncrement val p = f.position val vex = VariableExpression(f.variable) val indexType = ctx.env.get[Variable](f.variable).typ @@ -175,13 +176,13 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case (ForDirection.To | ForDirection.ParallelTo, _, Some(NumericConstant(255, _))) if indexType.size == 1 => compile(ctx, List( Assignment(vex, f.start).pos(p), - DoWhileStatement(f.body, List(increment), FunctionCallExpression("!=", List(vex, LiteralExpression(0, 1).pos(p))), names).pos(p) + DoWhileStatement(f.body, increment :: extraIncrement, FunctionCallExpression("!=", List(vex, LiteralExpression(0, 1).pos(p))), names).pos(p) )) case (ForDirection.To | ForDirection.ParallelTo, _, Some(NumericConstant(0xffff, _))) if indexType.size == 2 => compile(ctx, List( Assignment(vex, f.start).pos(p), - DoWhileStatement(f.body, List(increment), FunctionCallExpression("!=", List(vex, LiteralExpression(0, 2).pos(p))), names).pos(p) + DoWhileStatement(f.body, increment :: extraIncrement, FunctionCallExpression("!=", List(vex, LiteralExpression(0, 2).pos(p))), names).pos(p) )) case (ForDirection.Until | ForDirection.ParallelUntil, Some(c), Some(NumericConstant(256, _))) @@ -193,19 +194,19 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { // BNE loop compile(ctx, List( Assignment(vex, f.start).pos(p), - DoWhileStatement(f.body, List(increment), FunctionCallExpression("!=", List(vex, LiteralExpression(0, 1).pos(p))), names).pos(p) + DoWhileStatement(f.body, increment :: extraIncrement, FunctionCallExpression("!=", List(vex, LiteralExpression(0, 1).pos(p))), names).pos(p) )) case (ForDirection.ParallelUntil, Some(NumericConstant(0, _)), Some(NumericConstant(e, _))) if e > 0 => compile(ctx, List( Assignment(vex, f.end), - DoWhileStatement(Nil, decrement :: f.body, FunctionCallExpression("!=", List(vex, f.start)), names).pos(p) + DoWhileStatement(Nil, decrement :: (f.body ++ extraIncrement), FunctionCallExpression("!=", List(vex, f.start)), names).pos(p) )) case (ForDirection.Until, Some(NumericConstant(s, _)), Some(NumericConstant(e, _))) if s >= 0 && e > 0 && s < e => compile(ctx, List( Assignment(vex, f.start).pos(p), - DoWhileStatement(f.body, increment::Nil, FunctionCallExpression("!=", List(vex, f.end)).pos(p), names).pos(p) + DoWhileStatement(f.body, increment :: extraIncrement, FunctionCallExpression("!=", List(vex, f.end)).pos(p), names).pos(p) )) case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, esize))) if s == e => @@ -222,7 +223,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { ).pos(p), DoWhileStatement( decrement :: f.body, - Nil, + extraIncrement, FunctionCallExpression("!=", List(vex, f.end)).pos(p), names).pos(p) )) case (ForDirection.DownTo, Some(NumericConstant(s, ssize)), Some(NumericConstant(0, _))) if s > 0 => @@ -233,7 +234,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { ).pos(p), DoWhileStatement( decrement :: f.body, - Nil, + extraIncrement, FunctionCallExpression("!=", List(vex, f.end)).pos(p), names ).pos(p) @@ -245,7 +246,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { Assignment(vex, f.start).pos(p), WhileStatement( FunctionCallExpression("!=", List(vex, f.end)).pos(p), - f.body, List(increment), names).pos(p) + f.body, increment :: extraIncrement, names).pos(p) )) // case (ForDirection.To | ForDirection.ParallelTo, _, Some(NumericConstant(n, _))) if n > 0 && n < 255 => // compile(ctx, List( @@ -263,7 +264,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { List(IfStatement( FunctionCallExpression("==", List(vex, f.end)).pos(p), List(BreakStatement(f.variable).pos(p)), - List(increment) + increment :: extraIncrement )), names) )) @@ -274,7 +275,7 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { Assignment(vex, f.start).pos(p), DoWhileStatement( f.body, - List(decrement), + decrement :: extraIncrement, FunctionCallExpression("!=", List(vex, endMinusOne)).pos(p), names ).pos(p) @@ -317,8 +318,11 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { case Left(expr) => expr match { case VariableExpression(id) => - ctx.env.get[Thing](id) match { - case EnumType(_, Some(count)) => + (ctx.env.maybeGet[Thing](id), ctx.env.maybeGet[Thing](id + ".array")) match { + case (Some(EnumType(_, Some(count))), _) => + if (f.pointerVariable.isDefined) { + ctx.log.error("You can use only one variable when iteration over an enum type", f.position) + } return compile(ctx, ForStatement( f.variable, FunctionCallExpression(id, List(LiteralExpression(0, 1))), @@ -326,6 +330,64 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { ForDirection.ParallelUntil, f.body )) + case pair@( + (Some(ConstantThing(_, MemoryAddressConstant(_: MfArray), _)), _) | + (_, Some(_: MfArray)) + ) => + val arr: MfArray = pair match { + case (_, Some(a: MfArray)) => a + case (Some(ConstantThing(_, MemoryAddressConstant(a: MfArray) ,_)), _) => a + case _ => ??? + } + val (initialAssignment, inLoopAssignment, extraIncrement, orderImportant): + (List[ExecutableStatement], List[ExecutableStatement], List[ExecutableStatement], Boolean) = f.pointerVariable match { + case Some(pv) => + ctx.env.maybeGet[Variable](pv) match { + case Some(v: Variable) => + val elTyp = arr.elementType + val isValue = elTyp.isAssignableTo(v.typ) + val isPointer = v.typ match { + case PointerType(_, targetName, _) => elTyp.name == targetName + case _ => false + } + if (!isValue && !isPointer) { + ctx.log.error(s"Incompatible type for second iteration variable: got ${v.typ.name}, required ${elTyp.name} or pointer.${elTyp.name}", f.position) + } + val initialAss = if (isPointer) { + List(Assignment( + VariableExpression(pv), + VariableExpression(arr.name.stripSuffix(".array") + ".pointer") + )) + } else Nil + val inLoopAss = if (isValue) { + List(Assignment( + VariableExpression(pv), + IndexedExpression(arr.name, VariableExpression(f.variable)) + )) + } else Nil + val increment = if (isPointer) { + List(ExpressionStatement(FunctionCallExpression("+=", List( + VariableExpression(pv + ".raw"), + FunctionCallExpression("sizeof", List(VariableExpression(elTyp.name))) + )))) + } else Nil + (initialAss, inLoopAss, increment, isPointer) + case None => + ctx.log.error(s"Undefined variable: ${pv}", f.position) + (Nil, Nil, Nil, true) + } + case None => + (Nil, Nil, Nil, true) + } + val usesIterationVariable = f.body.exists(_.getAllExpressions.exists(_.containsVariable(f.variable))) + return compile(ctx, initialAssignment :+ ForStatement( + f.variable, + LiteralExpression(0, 1), + LiteralExpression(arr.elementCount, Constant.minimumSize(arr.elementCount)), + if (usesIterationVariable && orderImportant) ForDirection.Until else ForDirection.ParallelUntil, + inLoopAssignment ++ f.body, + extraIncrement + )) case _ => } case _ => diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index 728f066d..cea0de06 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -199,12 +199,12 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte val (b, _) = optimizeStmts(body, Map()) val (i, _) = optimizeStmts(inc, Map()) DoWhileStatement(b, i, c, labels).pos(pos) -> Map() - case f@ForEachStatement(v, arr, body) => + case f@ForEachStatement(v, pv, arr, body) => for (a <- arr.right.getOrElse(Nil)) cv = search(a, cv) val a = arr.map(_.map(optimizeExpr(_, Map()))) val (b, _) = optimizeStmts(body, Map()) - ForEachStatement(v, a, b).pos(pos) -> Map() - case f@ForStatement(v, st, en, dir, body) => + ForEachStatement(v, pv, a, b).pos(pos) -> Map() + case f@ForStatement(v, st, en, dir, body, Nil) => // detect a memset f.body match { @@ -269,6 +269,12 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte val (b, _) = optimizeStmts(body, Map()) ForStatement(v, s, e, dir, b).pos(pos) -> Map() } + case f@ForStatement(v, st, en, dir, body, increment) => + val s = optimizeExpr(st, Map()) + val e = optimizeExpr(en, Map()) + val (b, _) = optimizeStmts(body, Map()) + val (i, _) = optimizeStmts(increment, Map()) + ForStatement(v, s, e, dir, b, i).pos(pos) -> Map() case _ => stmt -> Map() } } diff --git a/src/main/scala/millfork/compiler/MacroExpander.scala b/src/main/scala/millfork/compiler/MacroExpander.scala index 7df5a696..2efeefa1 100644 --- a/src/main/scala/millfork/compiler/MacroExpander.scala +++ b/src/main/scala/millfork/compiler/MacroExpander.scala @@ -43,7 +43,7 @@ abstract class MacroExpander[T <: AbstractCode] { }) case WhileStatement(c, b, i, n) => WhileStatement(f(c), b.map(gx), i.map(gx), n) case DoWhileStatement(b, i, c, n) => DoWhileStatement(b.map(gx), i.map(gx), f(c), n) - case ForStatement(v, start, end, dir, body) => ForStatement(h(v), f(start), f(end), dir, body.map(gx)) + case ForStatement(v, start, end, dir, body, increment) => ForStatement(h(v), f(start), f(end), dir, body.map(gx), increment.map(gx)) case MemsetStatement(start, size, value, dir, original) => MemsetStatement(f(start), size, f(value), dir, original.map(gx).asInstanceOf[Option[ForStatement]]) case IfStatement(c, t, e) => IfStatement(f(c), t.map(gx), e.map(gx)) case s: Z80AssemblyStatement => s.copy(expression = f(s.expression), offsetExpression = s.offsetExpression.map(f)) @@ -86,7 +86,7 @@ abstract class MacroExpander[T <: AbstractCode] { }) case WhileStatement(c, b, i, n) => WhileStatement(f(c), b.map(gx), i.map(gx), n) case DoWhileStatement(b, i, c, n) => DoWhileStatement(b.map(gx), i.map(gx), f(c), n) - case ForStatement(v, start, end, dir, body) => ForStatement(h(v), f(start), f(end), dir, body.map(gx)) + case ForStatement(v, start, end, dir, body, increment) => ForStatement(h(v), f(start), f(end), dir, body.map(gx), increment.map(gx)) case MemsetStatement(start, size, value, dir, original) => MemsetStatement(f(start), size, f(value), dir, original.map(gx).asInstanceOf[Option[ForStatement]]) case IfStatement(c, t, e) => IfStatement(f(c), t.map(gx), e.map(gx)) case s: Z80AssemblyStatement => s.copy(expression = f(s.expression), offsetExpression = s.offsetExpression.map(f)) diff --git a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala index 952ad947..47572a70 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementCompiler.scala @@ -305,11 +305,11 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { compileDoWhileStatement(ctx, s) case f:MemsetStatement => MosBulkMemoryOperations.compileMemset(ctx, f) -> Nil - case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if !ctx.env.overlapsVariable(variable, source) => + case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression)), Nil) if !ctx.env.overlapsVariable(variable, source) => MosBulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil case f@ForStatement(variable, start, end, _, List(ExpressionStatement( FunctionCallExpression(operator@("+=" | "-=" | "+'=" | "-'=" | "|=" | "^=" | "&="), List(target: VariableExpression, source)) - ))) if !ctx.env.overlapsVariable(variable, source) && + )), Nil) if !ctx.env.overlapsVariable(variable, source) && !ctx.env.overlapsVariable(variable, target) && !ctx.env.overlapsVariable(target.name, start) && !ctx.env.overlapsVariable(target.name, end) => @@ -319,7 +319,7 @@ object MosStatementCompiler extends AbstractStatementCompiler[AssemblyLine] { } case f@ForStatement(variable, start, end, _, List(ExpressionStatement( FunctionCallExpression(operator@("+=" | "-=" | "<<=" | ">>="), List(target: IndexedExpression, source)) - ))) if !ctx.env.overlapsVariable(variable, source) && + )), Nil) if !ctx.env.overlapsVariable(variable, source) && !ctx.env.overlapsVariable(target.name, start) && !ctx.env.overlapsVariable(target.name, end) && target.name != variable => MosBulkMemoryOperations.compileMemmodify(ctx, target, operator, source, f) match { diff --git a/src/main/scala/millfork/compiler/mos/MosStatementPreprocessor.scala b/src/main/scala/millfork/compiler/mos/MosStatementPreprocessor.scala index 90402e1f..360fd4f6 100644 --- a/src/main/scala/millfork/compiler/mos/MosStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/mos/MosStatementPreprocessor.scala @@ -11,7 +11,7 @@ import millfork.node._ class MosStatementPreprocessor(ctx: CompilationContext, statements: List[ExecutableStatement]) extends AbstractStatementPreprocessor(ctx, statements) { def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = { - if (optimize && !f.variable.contains(".") && env.get[Variable](f.variable).typ.size == 2) { + if (f.extraIncrement.isEmpty && optimize && !f.variable.contains(".") && env.get[Variable](f.variable).typ.size == 2) { (env.eval(f.start), env.eval(f.end)) match { case (Some(NumericConstant(s, _)), Some(NumericConstant(e, _))) if (s & 0xffff) == s && (e & 0xffff) == e => f.direction match { diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala index b70ad5d3..fe333bd5 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementCompiler.scala @@ -143,16 +143,16 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { case f:MemsetStatement => Z80BulkMemoryOperations.compileMemset(ctx, f) -> Nil - case f@ForStatement(_, _, _, _, List(Assignment(target: IndexedExpression, source: IndexedExpression))) => + case f@ForStatement(_, _, _, _, List(Assignment(target: IndexedExpression, source: IndexedExpression)), Nil) => Z80BulkMemoryOperations.compileMemcpy(ctx, target, source, f) -> Nil - case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression))) if ctx.env.overlapsVariable(variable, source) => + case f@ForStatement(variable, _, _, _, List(Assignment(target: IndexedExpression, source: Expression)), Nil) if ctx.env.overlapsVariable(variable, source) => Z80BulkMemoryOperations.compileMemset(ctx, target, source, f) -> Nil case f@ForStatement(variable, _, _, _, List(ExpressionStatement(FunctionCallExpression( operator@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), List(target: IndexedExpression, source: Expression) - )))) => + ))), Nil) => Z80BulkMemoryOperations.compileMemtransform(ctx, target, operator, source, f) -> Nil case f@ForStatement(variable, _, _, _, List( @@ -164,7 +164,7 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { operator2@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), List(target2: IndexedExpression, source2: Expression) )) - )) => + ), Nil) => Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, operator2, source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( @@ -173,7 +173,7 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { operator2@("+=" | "-=" | "|=" | "&=" | "^=" | "+'=" | "-'=" | "<<=" | ">>="), List(target2: IndexedExpression, source2: Expression) )) - )) => + ), Nil) => Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, operator2, source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( @@ -182,13 +182,13 @@ object Z80StatementCompiler extends AbstractStatementCompiler[ZLine] { List(target1: IndexedExpression, source1: Expression) )), Assignment(target2: IndexedExpression, source2: Expression) - )) => + ), Nil) => Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, operator1, source1, target2, "=", source2, f) -> Nil case f@ForStatement(variable, _, _, _, List( Assignment(target1: IndexedExpression, source1: Expression), Assignment(target2: IndexedExpression, source2: Expression) - )) => + ), Nil) => Z80BulkMemoryOperations.compileMemtransform2(ctx, target1, "=", source1, target2, "=", source2, f) -> Nil case f: ForStatement => diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala b/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala index 0dbebf43..1331e952 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala @@ -39,7 +39,7 @@ class Z80StatementPreprocessor(ctx: CompilationContext, statements: List[Executa def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] = { - if (!ctx.options.flag(CompilationFlag.DangerousOptimizations)) return None + if (f.extraIncrement.isEmpty && !ctx.options.flag(CompilationFlag.DangerousOptimizations)) return None // TODO: figure out when this is useful // Currently all instances of arr[i] are replaced with arr`popt##`i[0], where arr`popt`i is a new pointer variable. // This breaks the main Millfork promise of not using hidden variables! diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 2e2b2578..8033146b 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -2328,6 +2328,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa nameCheck(s.start) nameCheck(s.end) nameCheck(s.body) + nameCheck(s.extraIncrement) + case s: ForEachStatement => + checkName[Variable]("Variable", s.variable, s.position) + s.pointerVariable.foreach(pv => checkName[Variable]("Variable", pv, s.position)) + nameCheck(s.body) case s:IfStatement => nameCheck(s.condition) nameCheck(s.thenBranch) diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index c5900f0c..a1679d75 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -755,14 +755,15 @@ object ForDirection extends Enumeration { val To, Until, DownTo, ParallelTo, ParallelUntil = Value } -case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement]) extends CompoundStatement { - override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: body.flatMap(_.getAllExpressions) +case class ForStatement(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: List[ExecutableStatement], extraIncrement: List[ExecutableStatement] = Nil) extends CompoundStatement { + override def getAllExpressions: List[Expression] = VariableExpression(variable) :: start :: end :: (body.flatMap(_.getAllExpressions) ++ extraIncrement.flatMap(_.getAllExpressions)) - override def getChildStatements: Seq[Statement] = body + override def getChildStatements: Seq[Statement] = body ++ extraIncrement override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { val b = body.map(f) - if (b.forall(_.isDefined)) Some(ForStatement(variable, start, end, direction, b.map(_.get)).pos(this.position)) + val i = extraIncrement.map(f) + if (b.forall(_.isDefined) && i.forall(_.isDefined)) Some(ForStatement(variable, start, end, direction, b.map(_.get), i.map(_.get)).pos(this.position)) else None } @@ -785,13 +786,13 @@ case class MemsetStatement(start: Expression, size: Constant, value: Expression, override def loopVariable: String = original.fold("_none")(_.loopVariable) } -case class ForEachStatement(variable: String, values: Either[Expression, List[Expression]], body: List[ExecutableStatement]) extends CompoundStatement { +case class ForEachStatement(variable: String, pointerVariable: Option[String], values: Either[Expression, List[Expression]], body: List[ExecutableStatement]) extends CompoundStatement { override def getAllExpressions: List[Expression] = VariableExpression(variable) :: (values.fold[List[Expression]](_ => Nil, identity) ++ body.flatMap(_.getAllExpressions)) override def getChildStatements: Seq[Statement] = body override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { val b = body.map(f) - if (b.forall(_.isDefined)) Some(ForEachStatement(variable,values, b.map(_.get)).pos(this.position)) + if (b.forall(_.isDefined)) Some(ForEachStatement(variable, pointerVariable, values, b.map(_.get)).pos(this.position)) else None } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index cec3fe00..cdc94010 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -544,8 +544,8 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri labelStatement | ifStatement | whileStatement | - forStatement | forEachStatement | + forStatement | doWhileStatement | breakStatement | continueStatement | @@ -618,7 +618,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } yield Seq(WhileStatement(condition, body.toList, Nil)) def forStatement: P[Seq[ExecutableStatement]] = for { - identifier <- "for" ~ SWS ~ identifier ~ HWS ~ "," ~/ HWS ~ Pass + identifier <- "for" ~/ SWS ~/ identifier ~/ HWS ~ "," ~/ HWS ~ Pass start <- mfExpression(nonStatementLevel, false) ~ HWS ~ "," ~/ HWS ~/ Pass direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass end <- mfExpression(nonStatementLevel, false) @@ -626,10 +626,10 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } yield Seq(ForStatement(identifier, start, end, direction, body.toList)) def forEachStatement: P[Seq[ExecutableStatement]] = for { - id <- "for" ~ SWS ~/ identifier ~/ HWS ~ ":" ~/ HWS ~ Pass + id <- "for" ~ SWS ~ identifier ~ HWS ~ ("," ~ HWS ~ identifier ~ HWS).? ~ ":" ~/ HWS ~ Pass values <- ("[" ~/ AWS ~/ mfExpression(0, false).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ "]" ~/ "").map(seq => Right(seq.toList)) | mfExpression(0, false).map(Left(_)) body <- AWS ~ executableStatements - } yield Seq(ForEachStatement(id, values, body.toList)) + } yield Seq(ForEachStatement(id._1, id._2, values, body.toList)) def inlineAssembly: P[Seq[ExecutableStatement]] = "asm" ~ !letterOrDigit ~/ AWS ~ asmStatements diff --git a/src/test/scala/millfork/test/ForLoopSuite.scala b/src/test/scala/millfork/test/ForLoopSuite.scala index e509b8bf..c90addfc 100644 --- a/src/test/scala/millfork/test/ForLoopSuite.scala +++ b/src/test/scala/millfork/test/ForLoopSuite.scala @@ -507,4 +507,89 @@ class ForLoopSuite extends FunSuite with Matchers { // OK } } + + + test("Looping across arrays") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | + | array source = [1,2,3,4,5] + | array target[5] @$c000 + | void main() { + | byte i + | for i:source { + | target[i] = source[i] + | } + | } + |""".stripMargin) { m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(2) + m.readByte(0xc002) should equal(3) + m.readByte(0xc003) should equal(4) + m.readByte(0xc004) should equal(5) + } + } + test("Looping across arrays 2") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | + | array source = [1,2,3,4,5] + | array target[5] @$c000 + | void main() { + | byte i + | byte w + | for i,w:source { + | target[i] = w + | } + | } + |""".stripMargin) { m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(2) + m.readByte(0xc002) should equal(3) + m.readByte(0xc003) should equal(4) + m.readByte(0xc004) should equal(5) + } + } + test("Looping across arrays 3") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | + | array source = [1,2,3,4,5] + | array target[5] @$c000 + | void main() { + | byte i + | pointer.byte p + | for i,p:source { + | target[i] = p[0] + | } + | } + |""".stripMargin) { m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(2) + m.readByte(0xc002) should equal(3) + m.readByte(0xc003) should equal(4) + m.readByte(0xc004) should equal(5) + } + } + test("Looping across arrays 4") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)( + """ + | + | array target[5] @$c000 = [1,2,3,4,5] + | void main() { + | byte i + | pointer.byte p + | for i,p:target { + | p[0] += 1 + | } + | } + |""".stripMargin) { m => + m.readByte(0xc000) should equal(2) + m.readByte(0xc001) should equal(3) + m.readByte(0xc002) should equal(4) + m.readByte(0xc003) should equal(5) + m.readByte(0xc004) should equal(6) + } + } + }