diff --git a/CHANGELOG.md b/CHANGELOG.md index 05cbe504..4ff3f806 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Current version +* Added array initialization syntax with `for` (not yet finalized). + * Fixed several bugs, most importantly invalid offsets for branching instructions. ## 0.2.2 diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index f2558d5b..2a604c31 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -625,6 +625,29 @@ class Environment(val parent: Option[Environment], val prefix: String) { addThing(array, None) } + def extractArrayContents(contents1: ArrayContents): List[Expression] = contents1 match { + case LiteralContents(xs) => xs + case CombinedContents(xs) => xs.flatMap(extractArrayContents) + case ForLoopContents(v, start, end, direction, body) => + (eval(start), eval(end)) match { + case (Some(NumericConstant(s, sz1)), Some(NumericConstant(e, sz2))) => + val size = sz1 max sz2 + val range = (direction match { + case ForDirection.To | ForDirection.ParallelTo => s.to(e) + case ForDirection.Until | ForDirection.ParallelUntil => s.until(e) + case ForDirection.DownTo => s.to(e, -1) + }).toList + range.flatMap(i => extractArrayContents(body).map(_.replaceVariable(v, LiteralExpression(i, size)))) + case (Some(_), Some(_)) => + ErrorReporting.error("Array range bounds cannot be evaluated") + Nil + case _ => + ErrorReporting.error("Non-constant array range bounds") + Nil + + } + } + def registerArray(stmt: ArrayDeclarationStatement): Unit = { val b = get[Type]("byte") val p = get[Type]("pointer") @@ -663,7 +686,8 @@ class Environment(val parent: Option[Environment], val prefix: String) { case _ => ErrorReporting.error(s"Array `${stmt.name}` has weird length", stmt.position) } } - case Some(contents) => + case Some(contents1) => + val contents = extractArrayContents(contents1) stmt.length match { case None => case Some(l) => diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 7ea5744a..00fec4ed 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -104,12 +104,28 @@ case class VariableDeclarationStatement(name: String, override def getAllExpressions: List[Expression] = List(initialValue, address).flatten } +trait ArrayContents extends Node { + def getAllExpressions: List[Expression] +} + +case class LiteralContents(contents: List[Expression]) extends ArrayContents { + override def getAllExpressions: List[Expression] = contents +} + +case class ForLoopContents(variable: String, start: Expression, end: Expression, direction: ForDirection.Value, body: ArrayContents) extends ArrayContents { + override def getAllExpressions: List[Expression] = start :: end :: body.getAllExpressions.map(_.replaceVariable(variable, LiteralExpression(0, 1))) +} + +case class CombinedContents(contents: List[ArrayContents]) extends ArrayContents { + override def getAllExpressions: List[Expression] = contents.flatMap(_.getAllExpressions) +} + case class ArrayDeclarationStatement(name: String, bank: Option[String], length: Option[Expression], address: Option[Expression], - elements: Option[List[Expression]]) extends DeclarationStatement { - override def getAllExpressions: List[Expression] = List(length, address).flatten ++ elements.getOrElse(Nil) + elements: Option[ArrayContents]) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = List(length, address).flatten ++ elements.fold(List[Expression]())(_.getAllExpressions) } case class ParameterDeclaration(typ: String, diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala index f5637b66..4be2c9d7 100644 --- a/src/main/scala/millfork/node/opt/UnusedFunctions.scala +++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala @@ -56,7 +56,8 @@ object UnusedFunctions extends NodeOptimization { def getAllCalledFunctions(expressions: List[Node]): List[String] = expressions.flatMap { case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) - case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) + case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList) + case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) case s: ReturnDispatchStatement => diff --git a/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala index dac09bc1..993facb1 100644 --- a/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala +++ b/src/main/scala/millfork/node/opt/UnusedGlobalVariables.scala @@ -49,7 +49,8 @@ object UnusedGlobalVariables extends NodeOptimization { def getAllReadVariables(expressions: List[Node]): List[String] = expressions.flatMap { case s: VariableDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.initialValue.toList) - case s: ArrayDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.elements.getOrElse(Nil)) + case s: ArrayDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.elements.toList) + case s: ArrayContents => getAllReadVariables(s.getAllExpressions) case s: FunctionDeclarationStatement => getAllReadVariables(s.address.toList) ++ getAllReadVariables(s.statements.getOrElse(Nil)) case Assignment(VariableExpression(_), expr) => getAllReadVariables(expr :: Nil) case ExpressionStatement(FunctionCallExpression(op, VariableExpression(_) :: params)) if op.endsWith("=") => getAllReadVariables(params) diff --git a/src/main/scala/millfork/output/InliningCalculator.scala b/src/main/scala/millfork/output/InliningCalculator.scala index 9b776183..72a91dea 100644 --- a/src/main/scala/millfork/output/InliningCalculator.scala +++ b/src/main/scala/millfork/output/InliningCalculator.scala @@ -59,7 +59,8 @@ object InliningCalculator { case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) case ReturnDispatchStatement(index, params, branches) => getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function)) - case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) + case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.toList) + case s: ArrayContents => getAllCalledFunctions(s.getAllExpressions) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) case AssemblyStatement(JSR, _, VariableExpression(name), true) => (name -> false) :: Nil diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 9baf00de..d281968c 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -194,9 +194,9 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o appc <- appcSimple | appcComplex } yield ParameterDeclaration(typ, appc).pos(p) - def arrayListElement: P[List[Expression]] = arrayStringContents | mlExpression(nonStatementLevel).map(List(_)) + def arrayListElement: P[ArrayContents] = arrayStringContents | arrayLoopContents | mlExpression(nonStatementLevel).map(e => LiteralContents(List(e))) - def arrayListContents: P[List[Expression]] = ("[" ~/ AWS ~/ arrayListElement.rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ "]" ~/ Pass).map(_.flatten.toList) + def arrayListContents: P[ArrayContents] = ("[" ~/ AWS ~/ arrayListElement.rep(sep = AWS ~ "," ~/ AWS) ~ AWS ~ "]" ~/ Pass).map(c => CombinedContents(c.toList)) val doubleQuotedString: P[List[Char]] = P("\"" ~/ CharsWhile(c => c != '\"' && c != '\n' && c != '\r').! ~ "\"").map(_.toList) @@ -210,7 +210,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o TextCodec.Ascii } - def arrayFileContents: P[List[Expression]] = for { + def arrayFileContents: P[ArrayContents] = for { p <- "file" ~ HWS ~/ "(" ~/ HWS ~/ position() filePath <- doubleQuotedString ~/ HWS optSlice <- ("," ~/ HWS ~/ literalAtom ~/ HWS ~/ "," ~/ HWS ~/ literalAtom ~/ HWS ~/ Pass).? @@ -220,14 +220,34 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o val slice = optSlice.fold(data) { case (start, length) => data.slice(start.value.toInt, start.value.toInt + length.value.toInt) } - slice.map(c => LiteralExpression(c & 0xff, 1)).toList + LiteralContents(slice.map(c => LiteralExpression(c & 0xff, 1)).toList) } - def arrayStringContents: P[List[Expression]] = P(position() ~ doubleQuotedString ~/ HWS ~ codec).map { - case (p, s, co) => s.map(c => LiteralExpression(co.decode(None, c), 1).pos(p)) + def arrayStringContents: P[ArrayContents] = P(position() ~ doubleQuotedString ~/ HWS ~ codec).map { + case (p, s, co) => LiteralContents(s.map(c => LiteralExpression(co.decode(None, c), 1).pos(p))) } - def arrayContents: P[List[Expression]] = arrayListContents | arrayFileContents | arrayStringContents + def arrayLoopContents: P[ArrayContents] = for { + identifier <- "for" ~ SWS ~/ identifier ~/ HWS ~ "," ~/ HWS ~ Pass + start <- mlExpression(nonStatementLevel) ~ HWS ~ "," ~/ HWS ~/ Pass + pos <- position() + direction <- forDirection ~/ HWS ~/ "," ~/ HWS ~/ Pass + end <- mlExpression(nonStatementLevel) + body <- AWS ~ arrayContents + } yield { + val fixedDirection = direction match { + case ForDirection.ParallelUntil => + ErrorReporting.warn("`paralleluntil` is not allowed in array definitions, assuming `until`", options, Some(pos)) + ForDirection.Until + case ForDirection.ParallelTo => + ErrorReporting.warn("`parallelto` is not allowed in array definitions, assuming `to`", options, Some(pos)) + ForDirection.To + case x => x + } + ForLoopContents(identifier, start, end, fixedDirection, body) + } + + def arrayContents: P[ArrayContents] = arrayListContents | arrayLoopContents | arrayFileContents | arrayStringContents def arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for { p <- position() diff --git a/src/test/scala/millfork/test/ForArraySuite.scala b/src/test/scala/millfork/test/ForArraySuite.scala new file mode 100644 index 00000000..0198ad71 --- /dev/null +++ b/src/test/scala/millfork/test/ForArraySuite.scala @@ -0,0 +1,43 @@ +package millfork.test + +import millfork.{Cpu, OptimizationPresets} +import millfork.assembly.opt.{AlwaysGoodOptimizations, DangerousOptimizations} +import millfork.test.emu._ +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ForArraySuite extends FunSuite with Matchers { + + test("Basic for-array test") { + val m = EmuSuperOptimizedRun( + """ + | byte output @$c000 + | array input = for i,0,until,8 [i + i] + | + | array useless0 = for x,0,until,8 for y,0,until,8 [0] + | array useless1 = for x,0,until,8 [for y,0,until,8 [0]] + | array useless2 = for x,0,until,8 "test" scr + | array useless3 = for x,0,until,8 [1 << x] + | array useless4 = for x,0,until,4 [(3 << (x * 2)) ^ 0xff] + | array useless5 = [ + | for x,0,until,4 [3] + | ] + | array useless6 = [ + | "foo" ascii, + | for x,0,until,4 [3] + | ] + | array useless7 = [ + | 7, + | for x,0,until,4 [3] + | ] + | + | void main () { + | output = useless0[0] + useless1[0] + useless2[0] + useless3[0] + useless4[0] + useless5[0] + useless6[0] + useless7[0] + | output = input.length + input[5] + | } + """.stripMargin) + m.readByte(0xc000) should equal(18) + } +}