diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 625989c4..4e7ada03 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -358,6 +358,13 @@ object AbstractExpressionCompiler { } } + def checkAssignmentType(env: Environment, source: Expression, targetType: Type): Unit = { + val sourceType = getExpressionType(env, env.log, source) + if (!sourceType.isAssignableTo(targetType)) { + env.log.error(s"Cannot assign `$sourceType` to `$targetType`", source.position) + } + } + def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = { val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x) env.lookupFunction(f.functionName, paramsWithTypes).getOrElse( diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index daac42cb..916c5174 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -1058,6 +1058,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val b = get[VariableType]("byte") val w = get[VariableType]("word") val p = get[Type]("pointer") + val e = get[VariableType](stmt.elementType) + if (e.size != 1) { + log.error(s"Array elements should be of size 1, `${e.name}` is of size ${e.size}", stmt.position) + } stmt.elements match { case None => stmt.length match { @@ -1084,9 +1088,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length)) val array = address match { case None => UninitializedArray(stmt.name + ".array", length.toInt, - declaredBank = stmt.bank, indexType, b, alignment) + declaredBank = stmt.bank, indexType, e, alignment) case Some(aa) => RelativeArray(stmt.name + ".array", aa, length.toInt, - declaredBank = stmt.bank, indexType, b) + declaredBank = stmt.bank, indexType, e) } addThing(array, stmt.position) registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, stmt.bank, alignment, isVolatile = false), stmt.position, options) @@ -1155,7 +1159,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (length > 0xffff || length < 0) log.error(s"Array `${stmt.name}` has invalid length", stmt.position) val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length)) val address = stmt.address.map(a => eval(a).getOrElse(errorConstant(s"Array `${stmt.name}` has non-constant address", stmt.position))) - val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, b, alignment) + for (element <- contents) { + AbstractExpressionCompiler.checkAssignmentType(this, element, e) + } + val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, e, alignment) addThing(array, stmt.position) registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, declaredBank = stmt.bank, alignment, isVolatile = false), stmt.position, options) @@ -1163,7 +1170,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case None => array.toAddress case Some(aa) => aa } - addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false, + addThing(RelativeVariable(stmt.name + ".first", a, e, zeropage = false, declaredBank = stmt.bank, isVolatile = false), stmt.position) if (options.flag(CompilationFlag.LUnixRelocatableCode)) { val b = get[Type]("byte") diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 5763a090..dbda16de 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -307,6 +307,7 @@ case class StructDefinitionStatement(name: String, fields: List[(String, String) case class ArrayDeclarationStatement(name: String, bank: Option[String], length: Option[Expression], + elementType: String, address: Option[Expression], elements: Option[ArrayContents], alignment: Option[MemoryAlignment]) extends DeclarationStatement { diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index f30167a7..a71a7398 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -236,12 +236,14 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri val arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for { p <- position() bank <- bankDeclaration - name <- "array" ~ !letterOrDigit ~/ SWS ~ identifier ~ HWS + _ <- "array" ~ !letterOrDigit + elementType <- ("(" ~/ AWS ~/ identifier ~ AWS ~ ")").? ~/ HWS + name <- identifier ~ HWS length <- ("[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~ "]").? ~ HWS alignment <- alignmentDeclaration(fastAlignmentForFunctions).? ~/ HWS addr <- ("@" ~/ HWS ~/ mfExpression(1, false)).? ~/ HWS contents <- ("=" ~/ HWS ~/ arrayContents).? ~/ HWS - } yield Seq(ArrayDeclarationStatement(name, bank, length, addr, contents, alignment).pos(p)) + } yield Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, contents, alignment).pos(p)) def tightMfExpression(allowIntelHex: Boolean): P[Expression] = { val a = if (allowIntelHex) atomWithIntel else atom diff --git a/src/test/scala/millfork/test/TypedArraySuite.scala b/src/test/scala/millfork/test/TypedArraySuite.scala new file mode 100644 index 00000000..55cbca52 --- /dev/null +++ b/src/test/scala/millfork/test/TypedArraySuite.scala @@ -0,0 +1,25 @@ +package millfork.test + +import millfork.test.emu.EmuUnoptimizedCmosRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class TypedArraySuite extends FunSuite with Matchers { + + test("Trivial assignments") { + val src = + """ + | enum e {} + | array(e) output [3] @$c000 + | void main () { + | output[0] = e(1) + | byte b + | b = byte(output[0]) + | } + """.stripMargin + val m = EmuUnoptimizedCmosRun(src) + m.readByte(0xc000) should equal(1) + } +}