diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 1fb4dc4f..e247158f 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -112,9 +112,34 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { // TODO: special faster cases val p = f.position val vex = VariableExpression(f.variable) + val indexType = ctx.env.get[Variable](f.variable).typ + val arithmetic = indexType.isArithmetic + if (!arithmetic && f.direction != ForDirection.ParallelUntil) { + ctx.log.error("Invalid direction for enum iteration", p) + compile(ctx, f.body) + return Nil -> Nil + } val one = LiteralExpression(1, 1).pos(p) - val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one)).pos(p)).pos(p) - val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one)).pos(p)).pos(p) + val increment = if (arithmetic) { + ExpressionStatement(FunctionCallExpression("+=", List(vex, one)).pos(p)).pos(p) + } else { + Assignment(vex, FunctionCallExpression(indexType.name, List( + SumExpression(List( + false -> FunctionCallExpression("byte", List(vex)).pos(p), + false -> LiteralExpression(1,1).pos(p), + ), decimal = false).pos(p) + )).pos(p)).pos(p) + } + val decrement = if (arithmetic) { + ExpressionStatement(FunctionCallExpression("-=", List(vex, one)).pos(p)).pos(p) + } else { + Assignment(vex, FunctionCallExpression(indexType.name, List( + SumExpression(List( + false -> FunctionCallExpression("byte", List(vex)).pos(p), + true -> LiteralExpression(1,1).pos(p), + ), decimal = false).pos(p) + )).pos(p)).pos(p) + } val names = Set("", "for", f.variable) val startEvaluated = ctx.env.eval(f.start) diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index c3d8c3ae..a6e99a3d 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -432,7 +432,7 @@ case class ForStatement(variable: String, start: Expression, end: Expression, di } case class ForEachStatement(variable: String, values: Either[Expression, List[Expression]], body: List[ExecutableStatement]) extends CompoundStatement { - override def getAllExpressions: List[Expression] = VariableExpression(variable) :: (values.fold(List(_), identity) ++ body.flatMap(_.getAllExpressions)) + override def getAllExpressions: List[Expression] = VariableExpression(variable) :: body.flatMap(_.getAllExpressions) override def getChildStatements: Seq[Statement] = body override def flatMap(f: ExecutableStatement => Option[ExecutableStatement]): Option[ExecutableStatement] = { diff --git a/src/test/scala/millfork/test/EnumSuite.scala b/src/test/scala/millfork/test/EnumSuite.scala index e684f539..8d6f28bb 100644 --- a/src/test/scala/millfork/test/EnumSuite.scala +++ b/src/test/scala/millfork/test/EnumSuite.scala @@ -66,6 +66,22 @@ class EnumSuite extends FunSuite with Matchers { """.stripMargin){_=>} } + test("Loops over enums") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | enum ugly { + | a + | b,c, + | d + | } + | void main() { + | ugly u + | for u:ugly { + | } + | } + """.stripMargin){_=>} + } + test("Enum-byte incompatibility test") { ShouldNotCompile( """ @@ -151,5 +167,15 @@ class EnumSuite extends FunSuite with Matchers { | return a[0] | } """.stripMargin) + + ShouldNotCompile( + """ + | enum ugly { a } + | array arr[ugly] + | void main() { + | byte x + | for x: ugly {} + | } + """.stripMargin) } }