From 1ba4b57c1c41f6ec0e3ae75832451f10a75346bd Mon Sep 17 00:00:00 2001
From: Karol Stasiak <karol.m.stasiak@gmail.com>
Date: Mon, 15 Apr 2019 01:57:18 +0200
Subject: [PATCH] Array elements can now be types other than byte

---
 .../compiler/AbstractExpressionCompiler.scala |  7 ++++++
 src/main/scala/millfork/env/Environment.scala | 15 ++++++++---
 src/main/scala/millfork/node/Node.scala       |  1 +
 src/main/scala/millfork/parser/MfParser.scala |  6 +++--
 .../scala/millfork/test/TypedArraySuite.scala | 25 +++++++++++++++++++
 5 files changed, 48 insertions(+), 6 deletions(-)
 create mode 100644 src/test/scala/millfork/test/TypedArraySuite.scala

diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala
index 625989c4..4e7ada03 100644
--- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala
+++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala
@@ -358,6 +358,13 @@ object AbstractExpressionCompiler {
     }
   }
 
+  def checkAssignmentType(env: Environment, source: Expression, targetType: Type): Unit = {
+    val sourceType = getExpressionType(env, env.log, source)
+    if (!sourceType.isAssignableTo(targetType)) {
+      env.log.error(s"Cannot assign `$sourceType` to `$targetType`", source.position)
+    }
+  }
+
   def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = {
     val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x)
     env.lookupFunction(f.functionName, paramsWithTypes).getOrElse(
diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala
index daac42cb..916c5174 100644
--- a/src/main/scala/millfork/env/Environment.scala
+++ b/src/main/scala/millfork/env/Environment.scala
@@ -1058,6 +1058,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
     val b = get[VariableType]("byte")
     val w = get[VariableType]("word")
     val p = get[Type]("pointer")
+    val e = get[VariableType](stmt.elementType)
+    if (e.size != 1) {
+      log.error(s"Array elements should be of size 1, `${e.name}` is of size ${e.size}", stmt.position)
+    }
     stmt.elements match {
       case None =>
         stmt.length match {
@@ -1084,9 +1088,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
                 val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length))
                 val array = address match {
                   case None => UninitializedArray(stmt.name + ".array", length.toInt,
-                              declaredBank = stmt.bank, indexType, b, alignment)
+                              declaredBank = stmt.bank, indexType, e, alignment)
                   case Some(aa) => RelativeArray(stmt.name + ".array", aa, length.toInt,
-                              declaredBank = stmt.bank, indexType, b)
+                              declaredBank = stmt.bank, indexType, e)
                 }
                 addThing(array, stmt.position)
                 registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, stmt.bank, alignment, isVolatile = false), stmt.position, options)
@@ -1155,7 +1159,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
         if (length > 0xffff || length < 0) log.error(s"Array `${stmt.name}` has invalid length", stmt.position)
         val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length))
         val address = stmt.address.map(a => eval(a).getOrElse(errorConstant(s"Array `${stmt.name}` has non-constant address", stmt.position)))
-        val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, b, alignment)
+        for (element <- contents) {
+          AbstractExpressionCompiler.checkAssignmentType(this, element, e)
+        }
+        val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, e, alignment)
         addThing(array, stmt.position)
         registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None,
                     declaredBank = stmt.bank, alignment, isVolatile = false), stmt.position, options)
@@ -1163,7 +1170,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
           case None => array.toAddress
           case Some(aa) => aa
         }
-        addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false,
+        addThing(RelativeVariable(stmt.name + ".first", a, e, zeropage = false,
                     declaredBank = stmt.bank, isVolatile = false), stmt.position)
         if (options.flag(CompilationFlag.LUnixRelocatableCode)) {
           val b = get[Type]("byte")
diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala
index 5763a090..dbda16de 100644
--- a/src/main/scala/millfork/node/Node.scala
+++ b/src/main/scala/millfork/node/Node.scala
@@ -307,6 +307,7 @@ case class StructDefinitionStatement(name: String, fields: List[(String, String)
 case class ArrayDeclarationStatement(name: String,
                                      bank: Option[String],
                                      length: Option[Expression],
+                                     elementType: String,
                                      address: Option[Expression],
                                      elements: Option[ArrayContents],
                                      alignment: Option[MemoryAlignment]) extends DeclarationStatement {
diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala
index f30167a7..a71a7398 100644
--- a/src/main/scala/millfork/parser/MfParser.scala
+++ b/src/main/scala/millfork/parser/MfParser.scala
@@ -236,12 +236,14 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
   val arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for {
     p <- position()
     bank <- bankDeclaration
-    name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS
+    _ <- "array" ~ !letterOrDigit
+    elementType <- ("(" ~/ AWS ~/ identifier ~ AWS ~ ")").? ~/ HWS
+    name <- identifier ~ HWS
     length <- ("[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~ "]").? ~ HWS
     alignment <- alignmentDeclaration(fastAlignmentForFunctions).? ~/ HWS
     addr <- ("@" ~/ HWS ~/ mfExpression(1, false)).? ~/ HWS
     contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS
-  } yield Seq(ArrayDeclarationStatement(name, bank, length, addr, contents, alignment).pos(p))
+  } yield Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, contents, alignment).pos(p))
 
   def tightMfExpression(allowIntelHex: Boolean): P[Expression] = {
     val a = if (allowIntelHex) atomWithIntel else atom
diff --git a/src/test/scala/millfork/test/TypedArraySuite.scala b/src/test/scala/millfork/test/TypedArraySuite.scala
new file mode 100644
index 00000000..55cbca52
--- /dev/null
+++ b/src/test/scala/millfork/test/TypedArraySuite.scala
@@ -0,0 +1,25 @@
+package millfork.test
+
+import millfork.test.emu.EmuUnoptimizedCmosRun
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+  * @author Karol Stasiak
+  */
+class TypedArraySuite extends FunSuite with Matchers {
+
+  test("Trivial assignments") {
+    val src =
+      """
+        | enum e {}
+        | array(e) output [3] @$c000
+        | void main () {
+        |   output[0] = e(1)
+        |   byte b
+        |   b = byte(output[0])
+        | }
+      """.stripMargin
+    val m = EmuUnoptimizedCmosRun(src)
+    m.readByte(0xc000) should equal(1)
+  }
+}