diff --git a/docs/lang/syntax.md b/docs/lang/syntax.md index ca713883..4c8b9af0 100644 --- a/docs/lang/syntax.md +++ b/docs/lang/syntax.md @@ -195,15 +195,18 @@ return [i] (param1, param2) { Return dispatch calculates the value of an index, picks the correct branch, assigns some global variables and jumps to another function. -The index has to evaluate to a byte. The functions cannot be `macro` and shouldn't have parameters. +The index has to evaluate to a byte or to an enum. The functions cannot be `macro` and shouldn't have parameters. Jumping to a function with parameters gives those parameters undefined values. The functions are not called, so they don't return to the function the return dispatch statement is in, but to its caller. The return values are passed along. If the dispatching function has a non-`void` return type different that the type of the function dispatched to, the return value is undefined. -If the `default` branch exists, then it is used for every missing index value between other supported values. -Optional parameters to `default` specify the maximum, or both the minimum and maximum supported index value. +If the `default` branch exists, then it is used for every missing index. +If the index type is an non-empty enum, then the default branch supports all the other values. +Otherwise, the `default` branch handles only the missing values between other supported values. +In this case, you can override it with optional parameters to `default`. +They specify the maximum, or both the minimum and maximum supported index value. In the above examples: the first example supports values 0–255, second 1–5, and third 0–20. If the index has an unsupported value, the behaviour is formally undefined, but in practice the program will simply crash. diff --git a/src/main/scala/millfork/compiler/AbstractReturnDispatch.scala b/src/main/scala/millfork/compiler/AbstractReturnDispatch.scala index 0096d2d3..e5129510 100644 --- a/src/main/scala/millfork/compiler/AbstractReturnDispatch.scala +++ b/src/main/scala/millfork/compiler/AbstractReturnDispatch.scala @@ -66,6 +66,19 @@ abstract class AbstractReturnDispatch[T <: AbstractCode] { var max = Option.empty[Int] var default = Option.empty[(Option[ThingInMemory], List[Expression])] stmt.branches.foreach { branch => + branch.label match { + case s: StandardReturnDispatchLabel => + for (label <- s.labels) { + verifyLabelCompatibility(ctx, indexerType, s, label) + } + case s@DefaultReturnDispatchLabel(start, end) => + for (label <- start) { + verifyLabelCompatibility(ctx, indexerType, s, label) + } + for (label <- end) { + verifyLabelCompatibility(ctx, indexerType, s, label) + } + } val function: String = ctx.env.evalForAsm(branch.function) match { case Some(MemoryAddressConstant(f: FunctionInMemory)) => if (f.returnType.name != returnType.name) { @@ -82,12 +95,21 @@ abstract class AbstractReturnDispatch[T <: AbstractCode] { ctx.log.error("Too many parameters for dispatch branch", branch.params.head.position) } branch.label match { - case DefaultReturnDispatchLabel(start, end) => + case s@DefaultReturnDispatchLabel(start, end) => if (default.isDefined) { ctx.log.error(s"Duplicate default dispatch label", branch.position) } - min = start.map(toInt) - max = end.map(toInt) + indexerType match { + case EnumType(_, Some(count)) => + if (start.isDefined || end.isDefined) { + ctx.log.error("Return dispatch over non-empty enum cannot have a different default range", s.position) + } + min = Some(0) + max = Some(count - 1) + case _ => + min = start.map(toInt) + max = end.map(toInt) + } default = Some(Some(ctx.env.get[FunctionInMemory](function)) -> params) case StandardReturnDispatchLabel(labels) => labels.foreach { label => @@ -140,14 +162,31 @@ abstract class AbstractReturnDispatch[T <: AbstractCode] { compileImpl(ctx, stmt, label, actualMin, actualMax, paramArrays, paramMins, map) } + private def verifyLabelCompatibility(ctx: CompilationContext, indexerType: Type, s: ReturnDispatchLabel, label: Expression): Unit = { + val labelType = AbstractExpressionCompiler.getExpressionType(ctx, label) + val bad = areIncompatible(indexerType, labelType) + if (bad) { + ctx.log.error(s"Incompatible return dispatch label type: expected `${indexerType.name}`, got `${labelType.name}`", label.position.orElse(s.position)) + } + } + + private def areIncompatible(indexerType: Type, labelType: Type) = { + (indexerType, labelType) match { + case (EnumType(n1, _), EnumType(n2, _)) => n1 != n2 + case (_, EnumType(n2, _)) => true + case (EnumType(n1, _), _) => true + case _ => false + } + } + def compileImpl(ctx: CompilationContext, - stmt: ReturnDispatchStatement, - label: String, - actualMin: Int, - actualMax: Int, - paramArrays: IndexedSeq[InitializedArray], - paramMins: IndexedSeq[Int], - map: mutable.Map[Int, (Option[ThingInMemory], List[Expression])]): List[T] + stmt: ReturnDispatchStatement, + label: String, + actualMin: Int, + actualMax: Int, + paramArrays: IndexedSeq[InitializedArray], + paramMins: IndexedSeq[Int], + map: mutable.Map[Int, (Option[ThingInMemory], List[Expression])]): List[T] protected def zeroOr(function: Option[ThingInMemory])(F: ThingInMemory => Constant): Expression = function.fold[Expression](LiteralExpression(0, 1))(F andThen ConstantArrayElementExpression) diff --git a/src/test/scala/millfork/test/ReturnDispatchSuite.scala b/src/test/scala/millfork/test/ReturnDispatchSuite.scala index 5b720524..d40b93d9 100644 --- a/src/test/scala/millfork/test/ReturnDispatchSuite.scala +++ b/src/test/scala/millfork/test/ReturnDispatchSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.EmuCrossPlatformBenchmarkRun +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedIntel8080Run, ShouldNotCompile} import org.scalatest.{FunSuite, Matchers} /** @@ -72,4 +72,99 @@ class ReturnDispatchSuite extends FunSuite with Matchers { m.readByte(0xc002) should equal(1) } } + + test("Enum test") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Cmos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | byte output @$c000 + | enum ugly { + | a + | b,c, + | d + | } + | void main () { + | ugly i + | i = a + | return [i] { + | a @ success + | } + | } + | void success() { + | output = 42 + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(42) + } + } + + test("Mixed types shouldn't compile 1") { + ShouldNotCompile( + """ + | enum ugly {a,b,c,d} + | void main () { + | ugly i + | return [i] { + | 1 @ success + | } + | } + | void success() {} + """.stripMargin) + } + + test("Mixed types shouldn't compile 2") { + ShouldNotCompile( + """ + | enum ugly {a,b,c,d} + | void main () { + | byte i + | return [i] { + | a @ success + | } + | } + | void success() {} + """.stripMargin) + } + + test("Non-empty enums can't have defined default ranges") { + ShouldNotCompile( + """ + | enum ugly {a,b,c,d} + | void main () { + | ugly i + | return [i] { + | a @ success + | default(a,d) @ success + | } + | } + | void success() {} + """.stripMargin) + } + + test("Empty enums can have defined default ranges") { + EmuUnoptimizedIntel8080Run( + """ + | enum ugly {} + | void main () { + | ugly i + | return [i] { + | default(ugly(0), ugly(10)) @ success + | } + | } + | void success() {} + """.stripMargin) + } + + test("Non-empty enums can have implied default ranges") { + EmuUnoptimizedIntel8080Run( + """ + | enum ugly {a,b,c} + | void main () { + | ugly i + | return [i] { + | default @ success + | } + | } + | void success() {} + """.stripMargin) + } }