From c9c0c16e98fce6c4338da4d79c97f2e1e5877f79 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 19 Mar 2018 21:58:51 +0100 Subject: [PATCH] Allowed more kinds of constants within variable and array initializers --- CHANGELOG.md | 2 + .../millfork/compiler/ReturnDispatch.scala | 54 ++++++++++++------- src/main/scala/millfork/env/Environment.scala | 31 +++++++---- src/main/scala/millfork/env/Thing.scala | 4 +- src/main/scala/millfork/node/Node.scala | 2 +- .../scala/millfork/output/Assembler.scala | 24 ++++++--- 6 files changed, 78 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 972981de..c063d834 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ * `-fzp-register` is now enabled by default, as the documentation has already been saying. +* Allowed more kinds of constants within variable and array initializers. + * Fixed several bugs. * Other improvements. diff --git a/src/main/scala/millfork/compiler/ReturnDispatch.scala b/src/main/scala/millfork/compiler/ReturnDispatch.scala index 9a81f43c..fc3f8e90 100644 --- a/src/main/scala/millfork/compiler/ReturnDispatch.scala +++ b/src/main/scala/millfork/compiler/ReturnDispatch.scala @@ -53,25 +53,23 @@ object ReturnDispatch { } val returnType = ctx.function.returnType - val map = mutable.Map[Int, (Constant, List[Constant])]() + val map: mutable.Map[Int, (String, List[Expression])] = mutable.Map() var min = Option.empty[Int] var max = Option.empty[Int] - var default = Option.empty[(Constant, List[Constant])] + var default = Option.empty[(String, List[Expression])] stmt.branches.foreach { branch => - val function = ctx.env.evalForAsm(branch.function).getOrElse { - ErrorReporting.error("Undefined function or Non-constant function address for dispatch branch", branch.function.position) - Constant.Zero + val function: String = ctx.env.evalForAsm(branch.function) match { + case Some(MemoryAddressConstant(f: FunctionInMemory)) => + if (f.returnType.name != returnType.name) { + ErrorReporting.warn(s"Dispatching to a function of different return type: dispatcher return type: ${returnType.name}, dispatchee return type: ${f.returnType.name}", ctx.options, branch.function.position) + } + f.name + case _ => + ErrorReporting.error("Undefined function or Non-constant function address for dispatch branch", branch.function.position) + "" } - if (returnType.name != "void") { - function match { - case MemoryAddressConstant(f: FunctionInMemory) => - if (f.returnType.name != returnType.name) { - ErrorReporting.warn(s"Dispatching to a function of different return type: dispatcher return type: ${returnType.name}, dispatchee return type: ${f.returnType.name}", ctx.options, branch.function.position) - } - case _ => () - } - } - val params = branch.params.map(toConstant) + branch.params.foreach(toConstant) + val params = branch.params if (params.length > stmt.params.length) { ErrorReporting.error("Too many parameters for dispatch branch", branch.params.head.position) } @@ -102,7 +100,7 @@ object ReturnDispatch { } val actualMin = defaultMin min nonDefaultMin.getOrElse(defaultMin) val actualMax = defaultMax max nonDefaultMax.getOrElse(defaultMax) - val zeroes = Constant.Zero -> List[Constant]() + val zeroes = "" -> List[Expression]() for (i <- actualMin to actualMax) { if (!map.contains(i)) map(i) = default.getOrElse { // TODO: warning? @@ -125,7 +123,7 @@ object ReturnDispatch { val label = MfCompiler.nextLabel("di") val paramArrays = stmt.params.indices.map { ix => val a = InitializedArray(label + "$" + ix + ".array", None, (paramMins(ix) to paramMaxes(ix)).map { key => - map(key)._2.lift(ix).getOrElse(Constant.Zero) + map(key)._2.lift(ix).getOrElse(LiteralExpression(0, 1)) }.toList, ctx.function.declaredBank) env.registerUnnamedArray(a) @@ -147,7 +145,7 @@ object ReturnDispatch { } if (useJmpaix) { - val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(map(i)._1.loByte, map(i)._1.hiByte)).toList, ctx.function.declaredBank) + val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(lobyte0(map(i)._1), hibyte0(map(i)._1))).toList, ctx.function.declaredBank) env.registerUnnamedArray(jumpTable) if (copyParams.isEmpty) { val loadIndex = ExpressionCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.A, b)), BranchSpec.None) @@ -162,8 +160,8 @@ object ReturnDispatch { } } else { val loadIndex = ExpressionCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.X, b)), BranchSpec.None) - val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).loByte).toList, ctx.function.declaredBank) - val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).hiByte).toList, ctx.function.declaredBank) + val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => lobyte1(map(i)._1)).toList, ctx.function.declaredBank) + val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => hibyte1(map(i)._1)).toList, ctx.function.declaredBank) env.registerUnnamedArray(jumpTableLo) env.registerUnnamedArray(jumpTableHi) loadIndex ++ copyParams ++ List( @@ -174,4 +172,20 @@ object ReturnDispatch { AssemblyLine.implied(RTS)) } } + + private def lobyte0(fname: String): Expression = if (fname == "") LiteralExpression(0, 1) else { + FunctionCallExpression("lo",List(VariableExpression(fname + ".addr"))) + } + + private def hibyte0(fname: String): Expression = if (fname == "") LiteralExpression(0, 1) else { + FunctionCallExpression("hi",List(VariableExpression(fname + ".addr"))) + } + + private def lobyte1(fname: String): Expression = if (fname == "") LiteralExpression(0, 1) else { + FunctionCallExpression("lo", List(SumExpression(List(false -> VariableExpression(fname + ".addr"), true -> LiteralExpression(1, 1)), decimal = false))) + } + + private def hibyte1(fname: String): Expression = if (fname == "") LiteralExpression(0, 1) else { + FunctionCallExpression("hi", List(SumExpression(List(false -> VariableExpression(fname + ".addr"), true -> LiteralExpression(1, 1)), decimal = false))) + } } diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 440e17f0..f2558d5b 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -348,6 +348,20 @@ class Environment(val parent: Option[Environment], val prefix: String) { } yield hc.asl(8) + lc case FunctionCallExpression(name, params) => name match { + case "hi" => + if (params.size == 1) { + eval(params.head).map(_.hiByte.quickSimplify) + } else { + ErrorReporting.error("Invalid number of parameters for `hi`", e.position) + None + } + case "lo" => + if (params.size == 1) { + eval(params.head).map(_.loByte.quickSimplify) + } else { + ErrorReporting.error("Invalid number of parameters for `lo`", e.position) + None + } case "nonet" => params match { case List(FunctionCallExpression("<<", ps@List(_, _))) => @@ -358,7 +372,10 @@ class Environment(val parent: Option[Environment], val prefix: String) { constantOperation(MathOperator.Plus9, ps.map(_._2)) case List(SumExpression(ps@List((true,_),(true,_)), true)) => constantOperation(MathOperator.DecimalPlus9, ps.map(_._2)) + case List(_) => + None case _ => + ErrorReporting.error("Invalid number of parameters for `nonet`", e.position) None } case ">>'" => @@ -660,9 +677,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { val length = contents.length if (length > 0xffff || length < 0) ErrorReporting.error(s"Array `${stmt.name}` has invalid length", stmt.position) val address = stmt.address.map(a => eval(a).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant address", stmt.position))) - val data = contents.map(x => eval(x).getOrElse(Constant.error(s"Array `${stmt.name}` has non-constant contents", stmt.position))) - val array = InitializedArray(stmt.name + ".array", address, data, - declaredBank = stmt.bank) + val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank) addThing(array, stmt.position) registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, declaredBank = stmt.bank), stmt.position) @@ -748,9 +763,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { if (options.flags(CompilationFlag.ReadOnlyArrays)) { ErrorReporting.warn("Initialized variable in read-only segment", options, position) } - val ivc = eval(ive).getOrElse(Constant.error(s"Initial value of `$name` is not a constant", position)) - InitializedMemoryVariable(name, None, typ, ivc, - declaredBank = stmt.bank) + InitializedMemoryVariable(name, None, typ, ive, declaredBank = stmt.bank) } registerAddressConstant(v, stmt.position) (v, v.toAddress) @@ -815,8 +828,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { def collectDeclarations(program: Program, options: CompilationOptions): Unit = { if (options.flag(CompilationFlag.OptimizeForSonicSpeed)) { - addThing(InitializedArray("identity$", None, List.tabulate(256)(n => NumericConstant(n, 1)), - declaredBank = None), None) + addThing(InitializedArray("identity$", None, List.tabulate(256)(n => LiteralExpression(n, 1)), declaredBank = None), None) } program.declarations.foreach { case f: FunctionDeclarationStatement => registerFunction(f, options) @@ -838,8 +850,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { address = None), options) } if (!things.contains("__constant8")) { - things("__constant8") = InitializedArray("__constant8", None, List(NumericConstant(8, 1)), - declaredBank = None) + things("__constant8") = InitializedArray("__constant8", None, List(LiteralExpression(8, 1)), declaredBank = None) } } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index b63e61d6..2f561489 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -155,7 +155,7 @@ case class UninitializedMemoryVariable(name: String, typ: Type, alloc: VariableA override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) } -case class InitializedMemoryVariable(name: String, address: Option[Constant], typ: Type, initialValue: Constant, declaredBank: Option[String]) extends MemoryVariable with PreallocableThing { +case class InitializedMemoryVariable(name: String, address: Option[Constant], typ: Type, initialValue: Expression, declaredBank: Option[String]) extends MemoryVariable with PreallocableThing { override def zeropage: Boolean = false override def toAddress: MemoryAddressConstant = MemoryAddressConstant(this) @@ -185,7 +185,7 @@ case class RelativeArray(name: String, address: Constant, sizeInBytes: Int, decl override def bank(compilationOptions: CompilationOptions): String = declaredBank.getOrElse("default") } -case class InitializedArray(name: String, address: Option[Constant], contents: List[Constant], declaredBank: Option[String]) extends MfArray with PreallocableThing { +case class InitializedArray(name: String, address: Option[Constant], contents: List[Expression], declaredBank: Option[String]) extends MfArray with PreallocableThing { override def shouldGenerate = true override def isFar(compilationOptions: CompilationOptions): Boolean = farFlag.getOrElse(false) diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 565acb05..7ea5744a 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -1,7 +1,7 @@ package millfork.node import millfork.assembly.{AddrMode, Opcode} -import millfork.env.{Label, ParamPassingConvention} +import millfork.env.{Constant, Label, ParamPassingConvention} case class Position(filename: String, line: Int, column: Int, cursor: Int) diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index 0efa2fb3..769180dd 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -215,7 +215,10 @@ class Assembler(private val program: Program, private val rootEnv: Environment, assembly.append("* = $" + index.toHexString) assembly.append(name) for (item <- items) { - writeByte(bank, index, item) + env.eval(item) match { + case Some(c) => writeByte(bank, index, c) + case None => ErrorReporting.error(s"Non-constant contents of array `$name`") + } bank0.occupied(index) = true bank0.initialized(index) = true bank0.writeable(index) = true @@ -281,7 +284,10 @@ class Assembler(private val program: Program, private val rootEnv: Environment, assembly.append("* = $" + index.toHexString) assembly.append(name) for (item <- items) { - writeByte(bank, index, item) + env.eval(item) match { + case Some(c) => writeByte(bank, index, c) + case None => ErrorReporting.error(s"Non-constant contents of array `$name`") + } index += 1 } items.grouped(16).foreach {group => @@ -298,10 +304,16 @@ class Assembler(private val program: Program, private val rootEnv: Environment, env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer")) assembly.append("* = $" + index.toHexString) assembly.append(name) - for (i <- 0 until typ.size) { - writeByte(bank, index, value.subbyte(i)) - assembly.append(" !byte " + value.subbyte(i).quickSimplify) - index += 1 + env.eval(value) match { + case Some(c) => + for (i <- 0 until typ.size) { + writeByte(bank, index, c.subbyte(i)) + assembly.append(" !byte " + c.subbyte(i).quickSimplify) + index += 1 + } + case None => + ErrorReporting.error(s"Non-constant initial value for variable `$name`") + index += typ.size } initializedVariablesSize += typ.size justAfterCode += bank -> index