diff --git a/docs/abi/variable-storage.md b/docs/abi/variable-storage.md index 942accee..c63645b2 100644 --- a/docs/abi/variable-storage.md +++ b/docs/abi/variable-storage.md @@ -46,6 +46,8 @@ but the main disadvantages are: * cannot use them in inline assembly code blocks +* structs and unions containing array fields are not supported + The implementation depends on the target architecture: * on 6502, the stack pointer is transferred into the X register and used as a base diff --git a/docs/lang/types.md b/docs/lang/types.md index dc1127ed..67a31c8f 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -256,11 +256,24 @@ as there are no checks on values when converting bytes to enumeration values and ## Structs -Struct is a compound type containing multiple fields of various types: +Struct is a compound type containing multiple fields of various types. +A struct is represented in memory as a contiguous area of variables or arrays laid out one after another. - struct [align (alignment)] { } +Declaration syntax: -A struct is represented in memory as a contiguous area of variables laid out one after another. + struct [align (alignment)] { } + +where a field definition is either: + +* ` ` and defines a scalar field, + +* or `array () []`, which defines an array field, + where the array contains items of type ``, + and either contains `` elements + if `` is a constant expression between 0 and 127, + or, if `` is a plain enumeration type, the array is indexed by that type, + and the number of elements is equal to the number of variants in that enumeration. + `()` can be omitted and defaults to `byte`. Struct can have a maximum size of 255 bytes. Larger structs are not supported. @@ -290,8 +303,8 @@ All arguments to the constructor must be constant. Structures declared with an alignment are allocated at appropriate memory addresses. The alignment has to be a power of two. -If the structs are in an array, they are padded with unused bytes. -If the struct is smaller that its alignment, then arrays of it are faster +If the structs with declared alignment are in an array, they are padded with unused bytes. +If the struct is smaller that its alignment, then arrays of it are faster than if it were not aligned struct a align(4) { byte x,byte y, byte z } struct b { byte x,byte y, byte z } @@ -309,9 +322,15 @@ If the struct is smaller that its alignment, then arrays of it are faster A struct that contains substructs or subunions with non-trivial alignments has its alignment equal to the least common multiple of the alignments of the substructs and its own declared alignment. +**Warning:** Limitations of array fields: + +* Structs containing arrays cannot be allocated on the stack. + +* Struct constructors for structs with array fields are not supported. + ## Unions - union [align (alignment)] { } + union [align (alignment)] { } Unions are pretty similar to structs, with the difference that all fields of the union start at the same point in memory and therefore overlap each other. @@ -327,3 +346,5 @@ start at the same point in memory and therefore overlap each other. Offset constants are also available, but they're obviously all zero. Unions currently do not have an equivalent of struct constructors. This may be improved on in the future. + +Unions with array fields have the same limitations as structs with array fields. diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index d77930df..3acf2864 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -378,7 +378,7 @@ object AbstractExpressionCompiler { log.error(s"Type `$targetType` doesn't have field named `$actualFieldName`", expr.position) ok = false } else { - if (tuples.head.arraySize.isDefined) ??? // TODO + if (tuples.head.arrayIndexTypeAndSize.isDefined) ??? // TODO pointerWrap match { case 0 => currentType = tuples.head.typ diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index a1eb6741..971b9868 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -13,6 +13,9 @@ import scala.collection.mutable.ListBuffer * @author Karol Stasiak */ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationContext, statements: List[ExecutableStatement]) { + implicit class StringToFunctionNameOps(val functionName: String) { + def <|(exprs: Expression*): Expression = FunctionCallExpression(functionName, exprs.toList).pos(exprs.head.position) + } type VV = Map[String, Constant] protected val optimize = true // TODO protected val env: Environment = ctx.env @@ -387,9 +390,6 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte case _ => } } - implicit class StringToFunctionNameOps(val functionName: String) { - def <|(exprs: Expression*): Expression = FunctionCallExpression(functionName, exprs.toList).pos(exprs.head.position) - } // generic warnings: expr match { case FunctionCallExpression("*" | "*=", params) => @@ -411,16 +411,32 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte val b = env.get[Type]("byte") var ok = true var result = optimizeExpr(root, currentVarValues).pos(pos) - def applyIndex(result: Expression, index: Expression): Expression = { + def applyIndex(result: Expression, index: Expression, guaranteedSmall: Boolean): Expression = { AbstractExpressionCompiler.getExpressionType(env, env.log, result) match { - case pt@PointerType(_, _, Some(target)) => - env.eval(index) match { - case Some(NumericConstant(0, _)) => //ok + case pt@PointerType(_, _, Some(targetType)) => + val zero = env.eval(index) match { + case Some(NumericConstant(0, _)) => + true case _ => - // TODO: should we keep this? - env.log.error(s"Type `$pt` can be only indexed with 0") + false + } + if (zero) { + DerefExpression(result, 0, targetType) + } else { + val indexType = AbstractExpressionCompiler.getExpressionType(env, env.log, index) + env.eval(index) match { + case Some(NumericConstant(n, _)) if n >= 0 && (guaranteedSmall || (targetType.alignedSize * n) <= 127) => + DerefExpression( + ("pointer." + targetType.name) <| result, + targetType.alignedSize * n.toInt, targetType) + case _ => + val small = guaranteedSmall || (indexType.size == 1 && !indexType.isSigned) + val scaledIndex: Expression = scaleIndexForArrayAccess(index, targetType, if (small) Some(256) else None) + DerefExpression(("pointer." + targetType.name) <| ( + ("pointer" <| result) #+# optimizeExpr(scaledIndex, Map()) + ), 0, targetType) + } } - DerefExpression(result, 0, target) case x if x.isPointy => val (targetType, arraySizeInBytes) = result match { case VariableExpression(maybePointy) => @@ -443,33 +459,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte targetType.alignedSize * n.toInt, targetType) } case _ => - val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize) - val shrunkElementSize = targetType.alignedSize >> shifts - val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts)) - val scaledIndex = arraySizeInBytes match { - // "n > targetType.alignedSize" means - // "don't do optimizations on arrays size 0 or 1" - case Some(n) if n > targetType.alignedSize && n <= 256 => targetType.alignedSize match { - case 1 => "byte" <| index - case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1)) - case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1)) - case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1)) - } - case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 => - "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) - case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 => - "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) - case Some(n) if n > targetType.alignedSize && shrunkArraySize <= 256 => - "<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1)) - case _ => targetType.alignedSize match { - case 1 => "word" <| index - case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1)) - case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1)) - case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1)) - } - } + val scaledIndex: Expression = scaleIndexForArrayAccess(index, targetType, arraySizeInBytes) // TODO: re-cast pointer type DerefExpression(("pointer." + targetType.name) <| ( result #+# optimizeExpr(scaledIndex, Map()) @@ -483,8 +473,9 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte } for (index <- firstIndices) { - result = applyIndex(result, index) + result = applyIndex(result, index, guaranteedSmall = false) } + var guaranteedSmall = false for ((dot, fieldName, indices) <- fieldPath) { if (dot && ok) { val pointer = result match { @@ -527,45 +518,79 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte ok = false LiteralExpression(0, 1) } else { - if (subvariables.head.arraySize.isDefined) ??? // TODO val inner = optimizeExpr(result, currentVarValues, optimizeSum = true).pos(pos) - val fieldOffset = subvariables.head.offset - val fieldType = subvariables.head.typ - pointerWrap match { - case 0 => - DerefExpression(inner, fieldOffset, fieldType) - case 1 => - if (fieldOffset == 0) { - ("pointer." + fieldType.name) <| ("pointer" <| inner) - } else { - ("pointer." + fieldType.name) <| ( - ("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2) - ) - } - case 2 => - if (fieldOffset == 0) { - "pointer" <| inner - } else { - ("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2) - } - case 10 => - if (fieldOffset == 0) { - "lo" <| ("pointer" <| inner) - } else { - "lo" <| ( - ("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2) - ) - } - case 11 => - if (fieldOffset == 0) { - "hi" <| ("pointer" <| inner) - } else { - "hi" <| ( - ("pointer" <| inner) #+# LiteralExpression(fieldOffset, 2) - ) + val subvariable = subvariables.head + val fieldOffset = subvariable.offset + val fieldType = subvariable.typ + val offsetExpression = LiteralExpression(fieldOffset, 2).pos(pos) + subvariable.arrayIndexTypeAndSize match { + case Some((indexType, arraySize)) => + guaranteedSmall = arraySize * target.alignedSize <= 256 + pointerWrap match { + case 0 | 1 => + if (fieldOffset == 0) { + ("pointer." + fieldType.name) <| ("pointer" <| inner) + } else { + ("pointer." + fieldType.name) <| (("pointer" <| inner) #+# offsetExpression) + } + case 2 => + if (fieldOffset == 0) { + ("pointer" <| inner) + } else { + ("pointer" <| inner) #+# offsetExpression + } + case 10 => + if (fieldOffset == 0) { + "lo" <| ("pointer" <| inner) + } else { + "lo" <| (("pointer" <| inner) #+# offsetExpression) + } + case 11 => + if (fieldOffset == 0) { + "hi" <| (("pointer" <| inner)) + } else { + "hi" <| (("pointer" <| inner) #+# offsetExpression) + } + case _ => throw new IllegalStateException } + case None => + guaranteedSmall = false + pointerWrap match { + case 0 => + DerefExpression(inner, fieldOffset, fieldType) + case 1 => + if (fieldOffset == 0) { + ("pointer." + fieldType.name) <| ("pointer" <| inner) + } else { + ("pointer." + fieldType.name) <| ( + ("pointer" <| inner) #+# offsetExpression + ) + } + case 2 => + if (fieldOffset == 0) { + "pointer" <| inner + } else { + ("pointer" <| inner) #+# offsetExpression + } + case 10 => + if (fieldOffset == 0) { + "lo" <| ("pointer" <| inner) + } else { + "lo" <| ( + ("pointer" <| inner) #+# offsetExpression + ) + } + case 11 => + if (fieldOffset == 0) { + "hi" <| ("pointer" <| inner) + } else { + "hi" <| ( + ("pointer" <| inner) #+# offsetExpression + ) + } - case _ => throw new IllegalStateException + case _ => throw new IllegalStateException + } } } case _ => @@ -576,7 +601,8 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte } if (ok) { for (index <- indices) { - result = applyIndex(result, index) + result = applyIndex(result, index, guaranteedSmall) + guaranteedSmall = false } } } @@ -710,6 +736,37 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte } } + private def scaleIndexForArrayAccess(index: Expression, targetType: Type, arraySizeInBytes: Option[Int]): Expression = { + val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize) + val shrunkElementSize = targetType.alignedSize >> shifts + val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts)) + val scaledIndex = arraySizeInBytes match { + // "n > targetType.alignedSize" means + // "don't do optimizations on arrays size 0 or 1" + case Some(n) if n > targetType.alignedSize && n <= 256 => targetType.alignedSize match { + case 1 => "byte" <| index + case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1)) + case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1)) + case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1)) + case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1)) + } + case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 => + "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) + case Some(n) if n > targetType.alignedSize && n <= 512 && targetType.alignedSize == 2 => + "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) + case Some(n) if n > targetType.alignedSize && shrunkArraySize <= 256 => + "<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1)) + case _ => targetType.alignedSize match { + case 1 => "word" <| index + case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1)) + case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1)) + case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1)) + case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1)) + } + } + scaledIndex + } + def pointlessCast(t1: String, expr: Expression): Boolean = { val typ1 = env.maybeGet[Type](t1).getOrElse(return false) val typ2 = getExpressionType(ctx, expr) diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 97ca1510..d111efb8 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -221,27 +221,27 @@ case class StructureConstant(typ: StructType, fields: List[Constant]) extends Co override def subbyte(index: Int): Constant = { var offset = 0 - for ((fv, ResolvedFieldDesc(ft, _, arraySize)) <- fields.zip(typ.mutableFieldsWithTypes)) { + for ((fv, ResolvedFieldDesc(ft, _, arrayIndexTypeAndSize)) <- fields.zip(typ.mutableFieldsWithTypes)) { // TODO: handle array members? val fs = ft.size if (index < offset + fs) { val indexInField = index - offset return fv.subbyte(indexInField) } - offset += fs * arraySize.getOrElse(1) + offset += fs * arrayIndexTypeAndSize.fold(1)(_._2) } Constant.Zero } override def subbyteBe(index: Int, totalSize: Int): Constant = { var offset = 0 - for ((fv, ResolvedFieldDesc(ft, _, arraySize)) <- fields.zip(typ.mutableFieldsWithTypes)) { + for ((fv, ResolvedFieldDesc(ft, _, arrayIndexTypeAndSize)) <- fields.zip(typ.mutableFieldsWithTypes)) { // TODO: handle array members? val fs = ft.size if (index < offset + fs) { val indexInField = index - offset return fv.subbyteBe(indexInField, fs) } - offset += fs * arraySize.getOrElse(1) + offset += fs * arrayIndexTypeAndSize.fold(1)(_._2) } Constant.Zero } diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index d667601b..8a35460b 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -896,6 +896,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (params.size == t.fields.size) { sequence(params.map(p => evalImpl(p, vv))).map(fields => StructureConstant(t, fields.zip(t.fields).map{ case (fieldConst, fieldDesc) => + if (fieldDesc.arraySize.isDefined) { + log.error(s"Cannot define a struct literal for a struct type ${t.name} with array fields", fce.position) + } fieldConst.fitInto(get[Type](fieldDesc.typeName)) })) } else None @@ -1172,8 +1175,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa else { val newPath = path + name var sum = 0 - for( ResolvedFieldDesc(fieldType, _, count) <- s.mutableFieldsWithTypes) { - val fieldSize = getTypeSize(fieldType, newPath) * count.getOrElse(1) + for( ResolvedFieldDesc(fieldType, _, indexTypeAndCount) <- s.mutableFieldsWithTypes) { + val fieldSize = getTypeSize(fieldType, newPath) * indexTypeAndCount.fold(1)(_._2) if (fieldSize < 0) return -1 sum = fieldType.alignment.roundSizeUp(sum) sum += fieldSize @@ -1185,10 +1188,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } val b = get[Type]("byte") var offset = 0 - for( ResolvedFieldDesc(fieldType, fieldName, count) <- s.mutableFieldsWithTypes) { + for( ResolvedFieldDesc(fieldType, fieldName, indexTypeAndCount) <- s.mutableFieldsWithTypes) { offset = fieldType.alignment.roundSizeUp(offset) addThing(ConstantThing(s"$name.$fieldName.offset", NumericConstant(offset, 1), b), None) - offset += getTypeSize(fieldType, newPath) * count.getOrElse(1) + offset += getTypeSize(fieldType, newPath) * indexTypeAndCount.fold(1)(_._2) offset = fieldType.alignment.roundSizeUp(offset) } sum @@ -1198,8 +1201,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa else { val newPath = path + name var max = 0 - for( ResolvedFieldDesc(fieldType, _, count) <- s.mutableFieldsWithTypes) { - val fieldSize = getTypeSize(fieldType, newPath) * count.getOrElse(1) + for( ResolvedFieldDesc(fieldType, _, indexTypeAndCount) <- s.mutableFieldsWithTypes) { + val fieldSize = getTypeSize(fieldType, newPath) * indexTypeAndCount.fold(1)(_._2) if (fieldSize < 0) return -1 max = max max fieldSize } @@ -1934,37 +1937,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case None => array.toAddress case Some(aa) => aa } - addThing(RelativeVariable(arrayName + ".first", a, b, zeropage = false, - declaredBank = stmt.bank, isVolatile = false), stmt.position) - if (options.flag(CompilationFlag.LUnixRelocatableCode)) { - val b = get[Type]("byte") - val w = get[Type]("word") - val relocatable = UninitializedMemoryVariable(arrayName, w, VariableAllocationMethod.Static, None, Set.empty, NoAlignment, isVolatile = false) - val addr = relocatable.toAddress - addThing(relocatable, stmt.position) - addThing(RelativeVariable(arrayName + ".addr.hi", addr + 1, b, zeropage = false, None, isVolatile = false), stmt.position) - addThing(RelativeVariable(arrayName + ".addr.lo", addr, b, zeropage = false, None, isVolatile = false), stmt.position) - addThing(RelativeVariable(arrayName + ".array.hi", addr + 1, b, zeropage = false, None, isVolatile = false), stmt.position) - addThing(RelativeVariable(arrayName + ".array.lo", addr, b, zeropage = false, None, isVolatile = false), stmt.position) - } else { - addThing(ConstantThing(arrayName, a, p), stmt.position) - addThing(ConstantThing(arrayName + ".hi", a.hiByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".lo", a.loByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".array.hi", a.hiByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".array.lo", a.loByte.quickSimplify, b), stmt.position) - } - if (length < 256) { - addThing(ConstantThing(arrayName + ".length", lengthConst, b), stmt.position) - } else { - addThing(ConstantThing(arrayName + ".length", lengthConst, w), stmt.position) - } - if (length > 0 && indexType.isArithmetic) { - if (length <= 256) { - addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 1), b), stmt.position) - } else { - addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 2), w), stmt.position) - } - } + registerArrayAddresses(arrayName, stmt.bank, a, indexType, e, length.toInt, alignment, stmt.position) case _ => log.error(s"Array `${stmt.name}` has weird length", stmt.position) } @@ -2012,41 +1985,56 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error(s"Preinitialized writable array `${stmt.name}` has to be in the default segment.", stmt.position) } addThing(array, stmt.position) - registerAddressConstant(UninitializedMemoryVariable(arrayName, p, VariableAllocationMethod.None, - declaredBank = stmt.bank, Set.empty, alignment, isVolatile = false), stmt.position, options, Some(e)) val a = address match { case None => array.toAddress case Some(aa) => aa } - addThing(RelativeVariable(arrayName + ".first", a, e, zeropage = false, - declaredBank = stmt.bank, isVolatile = false), stmt.position) - if (options.flag(CompilationFlag.LUnixRelocatableCode)) { - val b = get[Type]("byte") - val w = get[Type]("word") - val relocatable = UninitializedMemoryVariable(arrayName, w, VariableAllocationMethod.Static, None, Set.empty, NoAlignment, isVolatile = false) - val addr = relocatable.toAddress - addThing(relocatable, stmt.position) - addThing(RelativeVariable(arrayName + ".array.hi", addr + 1, b, zeropage = false, None, isVolatile = false), stmt.position) - addThing(RelativeVariable(arrayName + ".array.lo", addr, b, zeropage = false, None, isVolatile = false), stmt.position) - } else { - addThing(ConstantThing(arrayName, a, p), stmt.position) - addThing(ConstantThing(arrayName + ".hi", a.hiByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".lo", a.loByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".array.hi", a.hiByte.quickSimplify, b), stmt.position) - addThing(ConstantThing(arrayName + ".array.lo", a.loByte.quickSimplify, b), stmt.position) - } - if (length < 256) { - addThing(ConstantThing(arrayName + ".length", NumericConstant(length, 1), b), stmt.position) - } else { - addThing(ConstantThing(arrayName + ".length", NumericConstant(length, 2), w), stmt.position) - } - if (length > 0 && indexType.isArithmetic) { - if (length <= 256) { - addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 1), b), stmt.position) - } else { - addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 2), w), stmt.position) - } - } + registerArrayAddresses(arrayName, stmt.bank, a, indexType, e, length, alignment, stmt.position) + } + } + + def registerArrayAddresses( + arrayName: String, + declaredBank: Option[String], + address: Constant, + indexType: Type, + elementType: Type, + length: Int, + alignment: MemoryAlignment, + position: Option[Position]): Unit = { + val p = get[Type]("pointer") + val b = get[Type]("byte") + val w = get[Type]("word") + registerAddressConstant(UninitializedMemoryVariable(arrayName, p, VariableAllocationMethod.None, + declaredBank = declaredBank, Set.empty, alignment, isVolatile = false), position, options, Some(elementType)) + addThing(RelativeVariable(arrayName + ".first", address, elementType, zeropage = false, + declaredBank = declaredBank, isVolatile = false), position) + if (options.flag(CompilationFlag.LUnixRelocatableCode)) { + val b = get[Type]("byte") + val w = get[Type]("word") + val relocatable = UninitializedMemoryVariable(arrayName, w, VariableAllocationMethod.Static, None, Set.empty, NoAlignment, isVolatile = false) + val addr = relocatable.toAddress + addThing(relocatable, position) + addThing(RelativeVariable(arrayName + ".array.hi", addr + 1, b, zeropage = false, None, isVolatile = false), position) + addThing(RelativeVariable(arrayName + ".array.lo", addr, b, zeropage = false, None, isVolatile = false), position) + } else { + addThing(ConstantThing(arrayName, address, p), position) + addThing(ConstantThing(arrayName + ".hi", address.hiByte.quickSimplify, b), position) + addThing(ConstantThing(arrayName + ".lo", address.loByte.quickSimplify, b), position) + addThing(ConstantThing(arrayName + ".array.hi", address.hiByte.quickSimplify, b), position) + addThing(ConstantThing(arrayName + ".array.lo", address.loByte.quickSimplify, b), position) + } + if (length < 256) { + addThing(ConstantThing(arrayName + ".length", NumericConstant(length, 1), b), position) + } else { + addThing(ConstantThing(arrayName + ".length", NumericConstant(length, 2), w), position) + } + if (length > 0 && indexType.isArithmetic) { + if (length <= 256) { + addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 1), b), position) + } else { + addThing(ConstantThing(arrayName + ".lastindex", NumericConstant(length - 1, 2), w), position) + } } } @@ -2092,8 +2080,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (constantValue.requiredSize > typ.size) log.error(s"`$name` is has an invalid value: not in the range of `$typ`", position) addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position) for(Subvariable(suffix, offset, t, arraySize) <- getSubvariables(typ)) { - if (arraySize.isDefined) ??? // TODO - addThing(ConstantThing(prefix + name + suffix, constantValue.subconstant(options, offset, t.size), t), stmt.position) + if (arraySize.isDefined) { + log.error(s"Constants of type ${t.name} that contains array fields are not supported", stmt.position) + } else { + addThing(ConstantThing(prefix + name + suffix, constantValue.subconstant(options, offset, t.size), t), stmt.position) + } } } else { if (stmt.stack && stmt.global) log.error(s"`$name` is static or global and cannot be on stack", position) @@ -2165,39 +2156,44 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } def addVariable(options: CompilationOptions, localName: String, variable: Variable, position: Option[Position]): Unit = { + val b = get[VariableType]("byte") variable match { case v: StackVariable => addThing(localName, v, position) for (Subvariable(suffix, offset, t, arraySize) <- getSubvariables(v.typ)) { if (arraySize.isDefined) { - log.error(s"Cannot create a stack variable $localName of compound type ${v.typ.name} that contains an array member") + log.error(s"Cannot create a stack variable $localName of compound type ${v.typ.name} that contains an array member", position) } else { addThing(StackVariable(prefix + localName + suffix, t, baseStackOffset + offset), position) } } case v: MemoryVariable => addThing(localName, v, position) - for (Subvariable(suffix, offset, t, arraySize) <- getSubvariables(v.typ)) { - arraySize match { + for (Subvariable(suffix, offset, t, arrayIndexTypeAndSize) <- getSubvariables(v.typ)) { + arrayIndexTypeAndSize match { case None => val subv = RelativeVariable(prefix + localName + suffix, v.toAddress + offset, t, zeropage = v.zeropage, declaredBank = v.declaredBank, isVolatile = v.isVolatile) addThing(subv, position) registerAddressConstant(subv, position, options, Some(t)) - case Some(_) => - ??? // TODO + case Some((indexType, elemCount)) => + val suba = RelativeArray(prefix + localName + suffix + ".array", v.toAddress + offset, elemCount, v.declaredBank, indexType, t, false) + addThing(suba, position) + registerArrayAddresses(prefix + localName + suffix, v.declaredBank, v.toAddress + offset, indexType, t, elemCount, NoAlignment, position) } } case v: VariableInMemory => addThing(localName, v, position) addThing(ConstantThing(v.name + "`", v.toAddress, get[Type]("word")), position) - for (Subvariable(suffix, offset, t, arraySize) <- getSubvariables(v.typ)) { - arraySize match { + for (Subvariable(suffix, offset, t, arrayIndexTypeAndSize) <- getSubvariables(v.typ)) { + arrayIndexTypeAndSize match { case None => val subv = RelativeVariable(prefix + localName + suffix, v.toAddress + offset, t, zeropage = v.zeropage, declaredBank = v.declaredBank, isVolatile = v.isVolatile) addThing(subv, position) registerAddressConstant(subv, position, options, Some(t)) - case Some(_) => - ??? // TODO + case Some((indexType, elemCount)) => + val suba = RelativeArray(prefix + localName + suffix + ".array", v.toAddress + offset, elemCount, v.declaredBank, indexType, t, false) + addThing(suba, position) + registerArrayAddresses(prefix + localName + suffix, v.declaredBank, v.toAddress + offset, indexType, t, elemCount, NoAlignment, position) } } case _ => ??? @@ -2307,32 +2303,29 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case s: StructType => val builder = new ListBuffer[Subvariable] var offset = 0 - for(ResolvedFieldDesc(typ, fieldName, arraySize) <- s.mutableFieldsWithTypes) { + for(ResolvedFieldDesc(typ, fieldName, indexTypeAndCount) <- s.mutableFieldsWithTypes) { offset = getTypeAlignment(typ, Set()).roundSizeUp(offset) - arraySize match { - case None => - val suffix = "." + fieldName - builder += Subvariable(suffix, offset, typ, arraySize) - if (arraySize.isEmpty) { - builder ++= getSubvariables(typ).map { - case Subvariable(innerSuffix, innerOffset, innerType, innerSize) => Subvariable(suffix + innerSuffix, offset + innerOffset, innerType, innerSize) - } - } - case Some(_) => - // TODO + val suffix = "." + fieldName + builder += Subvariable(suffix, offset, typ, indexTypeAndCount) + if (indexTypeAndCount.isEmpty) { + builder ++= getSubvariables(typ).map { + case Subvariable(innerSuffix, innerOffset, innerType, innerSize) => Subvariable(suffix + innerSuffix, offset + innerOffset, innerType, innerSize) + } } - offset += typ.size * arraySize.getOrElse(1) + offset += typ.size * indexTypeAndCount.fold(1)(_._2) offset = getTypeAlignment(typ, Set()).roundSizeUp(offset) } builder.toList case s: UnionType => val builder = new ListBuffer[Subvariable] - for(FieldDesc(typeName, fieldName, _) <- s.fields) { + for(FieldDesc(typeName, fieldName, arraySize) <- s.fields) { val typ = get[VariableType](typeName) val suffix = "." + fieldName builder += Subvariable(suffix, 0, typ) - builder ++= getSubvariables(typ).map { - case Subvariable(innerSuffix, innerOffset, innerType, innerSize) => Subvariable(suffix + innerSuffix, innerOffset, innerType, innerSize) + if (arraySize.isEmpty) { + builder ++= getSubvariables(typ).map { + case Subvariable(innerSuffix, innerOffset, innerType, innerSize) => Subvariable(suffix + innerSuffix, innerOffset, innerType, innerSize) + } } } builder.toList @@ -2476,30 +2469,50 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + def getArrayFieldIndexTypeAndSize(expr: Expression): (VariableType, Int) = { + val b = get[VariableType]("byte") + expr match { + case VariableExpression(name) => + maybeGet[Type](name) match { + case Some(typ@EnumType(_, Some(count))) => + return typ -> count + case Some(typ) => + log.error(s"Type $name cannot be used as an array index", expr.position) + return b -> 0 + case _ => + } + case _ => + } + val constant: Int = eval(expr).map(_.quickSimplify) match { + case Some(NumericConstant(n, _)) if n >= 0 && n <= 127 => + n.toInt + case Some(NumericConstant(n, _)) => + log.error(s"Array size too large", expr.position) + 1 + case Some(_) => + log.error(s"Array size cannot be fully resolved", expr.position) + 1 + case _ => + errorConstant(s"Array has non-constant length", Some(expr), expr.position) + 1 + } + if (constant <= 256) { + b -> constant + } else { + get[VariableType]("word") -> constant + } + } + def fixStructFields(): Unit = { // TODO: handle arrays? things.values.foreach { case st@StructType(_, fields, _) => st.mutableFieldsWithTypes = fields.map { - case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map { x => - eval(x) match { - case Some(NumericConstant(c, _)) if c >= 0 && c < 0x10000 => c.toInt - case _ => - log.error(s"Invalid array size for member array $name in type ${st.toString}") - 0 - } - }) + case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map(getArrayFieldIndexTypeAndSize)) } case ut@UnionType(_, fields, _) => ut.mutableFieldsWithTypes = fields.map { - case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map { x => - eval(x) match { - case Some(NumericConstant(c, _)) if c >= 0 && c < 0x10000 => c.toInt - case _ => - log.error(s"Invalid array size for member array $name in type ${ut.toString}") - 0 - } - }) + case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map(getArrayFieldIndexTypeAndSize)) } case _ => () } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index a7468e95..a9941c52 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -55,7 +55,7 @@ sealed trait VariableType extends Type { } -case class Subvariable(suffix: String, offset: Int, typ: VariableType, arraySize: Option[Int] = None) +case class Subvariable(suffix: String, offset: Int, typ: VariableType, arrayIndexTypeAndSize: Option[(VariableType, Int)] = None) case object VoidType extends Type { def size = 0 diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index ad2e019e..e702303d 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -5,14 +5,14 @@ import millfork.assembly.m6809.{MAddrMode, MOpcode} import millfork.assembly.mos.opt.SourceOfNZ import millfork.assembly.mos.{AddrMode, Opcode} import millfork.assembly.z80.{NoRegisters, OneRegister, ZOpcode, ZRegisters} -import millfork.env.{Constant, ParamPassingConvention, Type, VariableType} +import millfork.env.{Constant, EnumType, ParamPassingConvention, Type, VariableType} import millfork.output.MemoryAlignment case class Position(moduleName: String, line: Int, column: Int, cursor: Int) case class FieldDesc(typeName:String, fieldName: String, arraySize: Option[Expression]) -case class ResolvedFieldDesc(typ:VariableType, fieldName: String, arraySize: Option[Int]) +case class ResolvedFieldDesc(typ:VariableType, fieldName: String, arrayIndexTypeAndSize: Option[(VariableType, Int)]) sealed trait Node { var position: Option[Position] = None diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 5a95d68a..0a42e07e 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -775,42 +775,46 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri variants <- enumVariants ~/ Pass } yield Seq(EnumDefinitionStatement(name, variants).pos(p)) - val compoundTypeField: P[FieldDesc] = ("array".! ~ HWS ~/ Pass).?.flatMap { + val compoundTypeField: P[FieldDesc] = ("array".! ~ !letterOrDigit ~ HWS ~/ Pass).?.flatMap { case None => - for { - typ <- identifier ~/ HWS - name <- identifier ~ HWS - } yield FieldDesc(typ, name, None) + (identifier ~/ HWS ~ identifier ~/ HWS).map { + case (typ, name) => FieldDesc(typ, name, None) + } case Some(_) => - for { - elementType <- ("(" ~/ AWS ~/ identifier ~ AWS ~ ")").? ~/ HWS - if enableDebuggingOptions - name <- identifier ~ HWS - length <- "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~ "]" ~ HWS - } yield FieldDesc(elementType.getOrElse("byte"), name, Some(length)) + (("(" ~/ AWS ~/ identifier ~ AWS ~ ")").? ~/ HWS ~/ + identifier ~ HWS ~ + "[" ~/ AWS ~/ mfExpression(nonStatementLevel, false) ~ AWS ~ "]" ~/ HWS + ).map{ + case (elementType, name, length) => + FieldDesc(elementType.getOrElse("byte"), name, Some(length)) + } } val compoundTypeFields: P[List[FieldDesc]] = ("{" ~/ AWS ~ compoundTypeField.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) - val structDefinition: P[Seq[StructDefinitionStatement]] = for { - p <- position() - _ <- "struct" ~ !letterOrDigit ~/ SWS ~ position("struct name") - name <- identifier ~/ HWS - align <- alignmentDeclaration(NoAlignment).? ~/ HWS - _ <- position("struct definition block") - fields <- compoundTypeFields ~/ Pass - } yield Seq(StructDefinitionStatement(name, fields, align).pos(p)) + val structDefinition: P[Seq[StructDefinitionStatement]] = { + (position() ~ "struct" ~ !letterOrDigit ~/ SWS ~/ + position("struct name").map(_ => ()) ~ identifier ~/ HWS ~ + alignmentDeclaration(NoAlignment).? ~/ HWS ~ + position("struct definition block").map(_ => ()) ~ + compoundTypeFields ~/ Pass).map{ + case (p, name, align, fields) => + Seq(StructDefinitionStatement(name, fields, align).pos(p)) + } + } - val unionDefinition: P[Seq[UnionDefinitionStatement]] = for { - p <- position() - _ <- "union" ~ !letterOrDigit ~/ SWS ~ position("union name") - name <- identifier ~/ HWS - align <- alignmentDeclaration(NoAlignment).? ~/ HWS - _ <- position("union definition block") - fields <- compoundTypeFields ~/ Pass - } yield Seq(UnionDefinitionStatement(name, fields, align).pos(p)) + val unionDefinition: P[Seq[UnionDefinitionStatement]] = { + (position() ~ "union" ~ !letterOrDigit ~/ SWS ~/ + position("union name").map(_ => ()) ~ identifier ~/ HWS ~ + alignmentDeclaration(NoAlignment).? ~/ HWS ~ + position("union definition block").map(_ => ()) ~ + compoundTypeFields ~/ Pass).map{ + case (p, name, align, fields) => + Seq(UnionDefinitionStatement(name, fields, align).pos(p)) + } + } val segmentBlock: P[Seq[BankedDeclarationStatement]] = for { (_, bankName) <- "segment" ~ AWS ~ "(" ~ AWS ~ position("segment name") ~ identifier ~ AWS ~ ")" ~ AWS ~ "{" ~/ AWS diff --git a/src/test/scala/millfork/test/StructSuite.scala b/src/test/scala/millfork/test/StructSuite.scala index 5485c30c..95328481 100644 --- a/src/test/scala/millfork/test/StructSuite.scala +++ b/src/test/scala/millfork/test/StructSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun} +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile, ShouldNotParse} import org.scalatest.{FunSuite, Matchers} /** @@ -226,4 +226,124 @@ class StructSuite extends FunSuite with Matchers { m.readByte(0xc001) should equal(2) } } + + test("Array struct fields") { + val code = + """ + |import zp_reg + |struct S { + | array tmp[8] + |} + | + |S output @$c000 + | + |array(S) outputAlias [1] @$c000 + | + |noinline byte id(byte x) = x + |noinline void dontOptimize(pointer.S dummy) {} + |void main() { + | output.tmp[0] = 1 + | output.tmp[4] = 4 + | pointer.S p + | p = output.pointer + | p->tmp[1] = 77 + | outputAlias[0].tmp[id(3)] = 3 + | outputAlias[id(0)].tmp[5] = 55 + |} + |""".stripMargin + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)(code){ m => + m.readByte(0xc000) should equal(1) + m.readByte(0xc001) should equal(77) + m.readByte(0xc003) should equal(3) + m.readByte(0xc004) should equal(4) + m.readByte(0xc005) should equal(55) + } + } + + test("Struct layout with array fields") { + val code = + """ + |struct S { + | array (word) a[4] + | byte x + |} + | + |array outputs [10] @$c000 + | + |void main() { + | S tmp + | outputs[1] = sizeof(S) + | outputs[2] = S.a.offset + | outputs[4] = S.x.offset + | outputs[5] = lo(tmp.a[1].addr - tmp.a[0].addr) + | outputs[6] = lo(tmp.x.addr - tmp.a.addr) + |} + |""".stripMargin + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)(code){ m => + m.readByte(0xc001) should equal(9) + m.readByte(0xc002) should equal(0) + m.readByte(0xc004) should equal(8) + m.readByte(0xc005) should equal(2) + m.readByte(0xc006) should equal(8) + } + } + + test("Structs with enum-indexed array fields") { + val code = + """ + | enum Suit { + | Hearts, Diamonds, Clubs, Spades + | } + | struct Deck { + | array(byte) count[Suit] + | } + | + | array output[5] @$c000 + | void main() { + | Deck d + | output[0] = d.count.length + | output[1] = sizeof(Deck) + | d.count[Diamonds] = 5 + | d.count[Clubs] = d.count[Diamonds] + | output[2] = d.count[Clubs] + | } + |""".stripMargin + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)(code){ m => + m.readByte(0xc000) should equal(4) + m.readByte(0xc001) should equal(4) + m.readByte(0xc002) should equal(5) + } + } + + test("Structs with array fields – invalid uses") { + ShouldNotCompile( + """ + | struct S { + | array a[4] + | } + | + | void main() { + | stack S s + | } + |""".stripMargin) + ShouldNotCompile( + """ + | struct S { + | array a[4] + | } + | void main() { + | S s + | s = S(4) + | } + |""".stripMargin) + } + + test("Structs with array fields – invalid syntax") { + ShouldNotParse( + """ + | struct S { + | byte a[4] + | } + |""".stripMargin) + } }