diff --git a/docs/lang/types.md b/docs/lang/types.md index 3f365206..02b42c3a 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -224,7 +224,7 @@ as there are no checks on values when converting bytes to enumeration values and Struct is a compound type containing multiple fields of various types: - struct <name> { <field definitions (type and name), separated by commas or newlines>} + struct <name> [align (alignment)] { <field definitions (type and name), separated by commas or newlines>} A struct is represented in memory as a contiguous area of variables laid out one after another. @@ -254,9 +254,27 @@ You can create constant expressions of struct types using so-called struct const All arguments to the constructor must be constant. +Structures declared with an alignment are allocated at appropriate memory addresses. +The alignment has to be a power of two. +If the structs are in an array, they are padded with unused bytes. +If the struct is smaller that its alignment, then arrays of it are faster + + struct a align(4) { byte x,byte y, byte z } + struct b { byte x,byte y, byte z } + array(a) as [4] @ $C000 + array(b) bs [4] @ $C800 + + a[1].addr - a[0].addr // equals 4 + b[1].addr - b[0].addr // equals 3 + sizeof(a) // equals 16 + sizeof(b) // equals 12 + + return a[i].x // requires XXXX cycles on 6502 + return b[i].x // requires XXXX cycles on 6502 + ## Unions - union <name> { <field definitions (type and name), separated by commas or newlines>} + union <name> [align (alignment)] { <field definitions (type and name), separated by commas or newlines>} Unions are pretty similar to structs, with the difference that all fields of the union start at the same point in memory and therefore overlap each other. diff --git a/src/main/scala/millfork/MathUtils.scala b/src/main/scala/millfork/MathUtils.scala new file mode 100644 index 00000000..946c0840 --- /dev/null +++ b/src/main/scala/millfork/MathUtils.scala @@ -0,0 +1,15 @@ +package millfork + +/** + * @author Karol Stasiak + */ +object MathUtils { + + @scala.annotation.tailrec + def gcd(a: Int, b: Int): Int = + if (b == 0) a else gcd(b, a % b) + + def lcm(a: Int, b: Int): Int = + (a.toLong & b.toLong / gcd(a, b)).toInt + +} diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index bb833fd7..728f066d 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -64,7 +64,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte } def isWordPointy(name: String): Boolean = { - env.getPointy(name).elementType.size == 2 + env.getPointy(name).elementType.alignedSize == 2 } def optimizeStmt(stmt: ExecutableStatement, currentVarValues: VV): (ExecutableStatement, VV) = { @@ -409,32 +409,39 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte } ctx.log.trace(s"$result is $x and targets $targetType") env.eval(index) match { - case Some(NumericConstant(n, _)) if n >= 0 && (targetType.size * n) <= 127 => + case Some(NumericConstant(n, _)) if n >= 0 && (targetType.alignedSize * n) <= 127 => x match { case _: PointerType => - DerefExpression(result, targetType.size * n.toInt, targetType) + DerefExpression(result, targetType.alignedSize * n.toInt, targetType) case _ => DerefExpression( ("pointer." + targetType.name) <| result, - targetType.size * n.toInt, targetType) + targetType.alignedSize * n.toInt, targetType) } case _ => + val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize) + val shrunkElementSize = targetType.alignedSize >> shifts + val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts)) val scaledIndex = arraySizeInBytes match { - case Some(n) if n <= 256 => targetType.size match { + case Some(n) if n <= 256 => targetType.alignedSize match { case 1 => "byte" <| index case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1)) case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1)) case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.size, 1)) + case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1)) } - case Some(n) if n <= 512 && targetType.size == 2 => + case Some(n) if n <= 512 && targetType.alignedSize == 2 => "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) - case _ => targetType.size match { + case Some(n) if n <= 512 && targetType.alignedSize == 2 => + "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) + case Some(_) if shrunkArraySize <= 256 => + "<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1)) + case _ => targetType.alignedSize match { case 1 => "word" <| index case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1)) case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1)) case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("word" <| index, LiteralExpression(targetType.size, 1)) + case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1)) } } // TODO: re-cast pointer type @@ -618,13 +625,13 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte case IndexedExpression(name, index) => val pointy = env.getPointy(name) val targetType = pointy.elementType - targetType.size match { + targetType.alignedSize match { case 1 => IndexedExpression(name, optimizeExpr(index, Map())).pos(pos) case _ => val constantOffset: Option[Long] = env.eval(index) match { case Some(z) if z.isProvablyZero => Some(0L) case Some(NumericConstant(n, _)) => - if (targetType.size * (n+1) <= 256) Some(targetType.size * n) else None + if (targetType.alignedSize * (n+1) <= 256) Some(targetType.alignedSize * n) else None case _ => None } constantOffset match { @@ -643,22 +650,27 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte case p: ConstantPointy => p.sizeInBytes case _ => None } + val shifts = Integer.numberOfTrailingZeros(targetType.alignedSize) + val shrunkElementSize = targetType.alignedSize >> shifts + val shrunkArraySize = arraySizeInBytes.fold(9999)(_.>>(shifts)) val scaledIndex = arraySizeInBytes match { - case Some(n) if n <= 256 => targetType.size match { + case Some(n) if n <= 256 => targetType.alignedSize match { case 1 => "byte" <| index case 2 => "<<" <| ("byte" <| index, LiteralExpression(1, 1)) case 4 => "<<" <| ("byte" <| index, LiteralExpression(2, 1)) case 8 => "<<" <| ("byte" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.size, 1)) + case _ => "*" <| ("byte" <| index, LiteralExpression(targetType.alignedSize, 1)) } - case Some(n) if n <= 512 && targetType.size == 2 => + case Some(n) if n <= 512 && targetType.alignedSize == 2 => "nonet" <| ("<<" <| ("byte" <| index, LiteralExpression(1, 1))) - case _ => targetType.size match { + case Some(_) if shrunkArraySize <= 256 => + "<<" <| ("word" <| ("*" <| ("byte" <| index, LiteralExpression(shrunkElementSize, 1))), LiteralExpression(shifts, 1)) + case _ => targetType.alignedSize match { case 1 => "word" <| index case 2 => "<<" <| ("word" <| index, LiteralExpression(1, 1)) case 4 => "<<" <| ("word" <| index, LiteralExpression(2, 1)) case 8 => "<<" <| ("word" <| index, LiteralExpression(3, 1)) - case _ => "*" <| ("word" <| index, LiteralExpression(targetType.size, 1)) + case _ => "*" <| ("word" <| index, LiteralExpression(targetType.alignedSize, 1)) } } DerefExpression( diff --git a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala index 42978282..d9460923 100644 --- a/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/m6809/M6809ExpressionCompiler.scala @@ -142,7 +142,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { } case v:VariablePointy => val (prepareIndex, offset): (List[MLine], Constant) = ctx.env.eval(index) match { - case Some(ix) => List(MLine.absolute(LDX, v.addr)) -> (ix * v.elementType.size).quickSimplify + case Some(ix) => List(MLine.absolute(LDX, v.addr)) -> (ix * v.elementType.alignedSize).quickSimplify case _ => v.indexType.size match { case 1 | 2 => @@ -167,7 +167,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { } case v:StackVariablePointy => ctx.env.eval(index) match { - case Some(ix) => List(MLine.variablestack(ctx, LDX, v.offset), MLine.indexedX(LDB, ix * v.elementType.size)) + case Some(ix) => List(MLine.variablestack(ctx, LDX, v.offset), MLine.indexedX(LDB, ix * v.elementType.alignedSize)) } } case e@SumExpression(expressions, decimal) => @@ -739,7 +739,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { ctx.env.eval(index) match { case Some(ix) => if (ix.isProvablyZero) List(MLine(store, Absolute(true), v.addr)) - else List(MLine.absolute(LDX, v.addr), MLine.indexedX(store, ix * v.elementType.size)) + else List(MLine.absolute(LDX, v.addr), MLine.indexedX(store, ix * v.elementType.alignedSize)) case _ => v.indexType.size match { case 1 | 2 => @@ -752,7 +752,7 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { } case v: StackVariablePointy => ctx.env.eval(index) match { - case Some(ix) => List(MLine.variablestack(ctx, LDX, v.offset), MLine.indexedX(store, ix * v.elementType.size)) + case Some(ix) => List(MLine.variablestack(ctx, LDX, v.offset), MLine.indexedX(store, ix * v.elementType.alignedSize)) } } } @@ -827,14 +827,14 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] { compileToX(ctx, inner) :+ MLine.indexedX(MOpcode.LEAX, Constant(offset)) case IndexedExpression(aname, index) => ctx.env.getPointy(aname) match { - case p: VariablePointy => compileToD(ctx, index #*# p.elementType.size) ++ List(MLine.absolute(ADDD, p.addr), MLine.tfr(M6809Register.D, M6809Register.X)) + case p: VariablePointy => compileToD(ctx, index #*# p.elementType.alignedSize) ++ List(MLine.absolute(ADDD, p.addr), MLine.tfr(M6809Register.D, M6809Register.X)) case p: ConstantPointy => if (p.sizeInBytes.exists(_ < 255)) { - compileToB(ctx, index #*# p.elementType.size) ++ List(MLine.immediate(LDX, p.value), MLine.inherent(ABX)) + compileToB(ctx, index #*# p.elementType.alignedSize) ++ List(MLine.immediate(LDX, p.value), MLine.inherent(ABX)) } else { - compileToX(ctx, index #*# p.elementType.size) :+ MLine.indexedX(LEAX, p.value) + compileToX(ctx, index #*# p.elementType.alignedSize) :+ MLine.indexedX(LEAX, p.value) } - case p:StackVariablePointy => compileToD(ctx, index #*# p.elementType.size) ++ List(MLine.variablestack(ctx, ADDD, p.offset), MLine.tfr(M6809Register.D, M6809Register.X)) + case p:StackVariablePointy => compileToD(ctx, index #*# p.elementType.alignedSize) ++ List(MLine.variablestack(ctx, ADDD, p.offset), MLine.tfr(M6809Register.D, M6809Register.X)) } case _ => ??? } diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index db46bc11..f543f64f 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -1309,7 +1309,6 @@ object BuiltIns { val adc = addingCode.last val indexing = addingCode.init result ++= indexing - result += AssemblyLine.immediate(LDA, 0) val mult = c & 0xff var mask = 128 var empty = true @@ -1318,8 +1317,12 @@ object BuiltIns { result += AssemblyLine.implied(ASL) } if ((mult & mask) != 0) { + if (empty) { + result += adc.copy(opcode = LDA) + } else { result += AssemblyLine.implied(CLC) result += adc + } empty = false } diff --git a/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala b/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala index d092457c..60999c98 100644 --- a/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala +++ b/src/main/scala/millfork/compiler/mos/MosBulkMemoryOperations.scala @@ -19,7 +19,7 @@ object MosBulkMemoryOperations { return MosStatementCompiler.compileForStatement(ctx, f)._1 } val pointy = ctx.env.getPointy(target.name) - if (pointy.elementType.size != 1) { + if (pointy.elementType.alignedSize != 1) { return MosStatementCompiler.compileForStatement(ctx, f)._1 } val w = ctx.env.get[Type]("word") diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 12e8c302..e8f5fc26 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -967,7 +967,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case IndexedExpression(arrayName, indexExpr) => val pointy = env.getPointy(arrayName) AbstractExpressionCompiler.checkIndexType(ctx, pointy, indexExpr) - if (pointy.elementType.size != 1) ctx.log.fatal("Whee!") // the statement preprocessor should have removed all of those + if (pointy.elementType.alignedSize != 1) ctx.log.fatal("Whee!") // the statement preprocessor should have removed all of those // TODO: check val (variableIndex, constantIndex) = env.evalVariableAndConstantSubParts(indexExpr) val variableIndexSize = variableIndex.map(v => getExpressionType(ctx, v).size).getOrElse(0) diff --git a/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala index 72bbaad2..4457be62 100644 --- a/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala +++ b/src/main/scala/millfork/compiler/z80/Z80BulkMemoryOperations.scala @@ -435,7 +435,7 @@ object Z80BulkMemoryOperations { loadA: ZOpcode.Value => List[ZLine], z80Bulk: Boolean => Option[ZOpcode.Value]): List[ZLine] = { val pointy = ctx.env.getPointy(target.name) - if (pointy.elementType.size > 1) return Z80StatementCompiler.compileForStatement(ctx, f)._1 + if (pointy.elementType.alignedSize > 1) return Z80StatementCompiler.compileForStatement(ctx, f)._1 val targetOffset = removeVariableOnce(ctx, f.variable, target.index).getOrElse(return compileForStatement(ctx, f)._1) if (!targetOffset.isPure) return compileForStatement(ctx, f)._1 val indexVariableSize = ctx.env.get[Variable](f.variable).typ.size diff --git a/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala b/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala index 8b6926ee..0dbebf43 100644 --- a/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/z80/Z80StatementPreprocessor.scala @@ -29,7 +29,7 @@ class Z80StatementPreprocessor(ctx: CompilationContext, statements: List[Executa case f: DerefDebuggingExpression => Nil case IndexedExpression(a, VariableExpression(v)) => if (v == variable) { ctx.env.maybeGet[Thing](a + ".array") match { - case Some(array: MfArray) if array.elementType.size == 1 => Seq(a) + case Some(array: MfArray) if array.elementType.alignedSize == 1 => Seq(a) case _ => Nil } } else Nil diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index dc391990..6b00e54b 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -384,9 +384,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa InitializedMemoryVariable UninitializedMemoryVariable getArrayOrPointer(name) match { - case th@InitializedArray(_, _, cs, _, i, e, ro, _) => ConstantPointy(th.toAddress, Some(name), Some(e.size * cs.length), Some(cs.length), i, e, th.alignment, readOnly = ro) - case th@UninitializedArray(_, elementCount, _, i, e, ro, _) => ConstantPointy(th.toAddress, Some(name), Some(elementCount * e.size), Some(elementCount / e.size), i, e, th.alignment, readOnly = ro) - case th@RelativeArray(_, _, elementCount, _, i, e, ro) => ConstantPointy(th.toAddress, Some(name), Some(elementCount * e.size), Some(elementCount / e.size), i, e, NoAlignment, readOnly = ro) + case th@InitializedArray(_, _, cs, _, i, e, ro, _) => ConstantPointy(th.toAddress, Some(name), Some(e.alignedSize * cs.length), Some(cs.length), i, e, th.alignment, readOnly = ro) + case th@UninitializedArray(_, elementCount, _, i, e, ro, _) => ConstantPointy(th.toAddress, Some(name), Some(elementCount * e.alignedSize), Some(elementCount / e.size), i, e, th.alignment, readOnly = ro) + case th@RelativeArray(_, _, elementCount, _, i, e, ro) => ConstantPointy(th.toAddress, Some(name), Some(elementCount * e.alignedSize), Some(elementCount / e.size), i, e, NoAlignment, readOnly = ro) case ConstantThing(_, value, typ) if typ.size <= 2 && typ.isPointy => val e = get[VariableType](typ.pointerTargetName) val w = get[VariableType]("word") @@ -535,8 +535,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa BranchingOpcodeMapping(Opcode.BPL, IfFlagClear(ZFlag.S), MOpcode.BPL), BranchingOpcodeMapping(Opcode.BMI, IfFlagSet(ZFlag.S), MOpcode.BMI)), None) - val byte_and_pointer$ = StructType("byte_and_pointer$", List(FieldDesc("byte", "zp", None), FieldDesc("pointer", "branch", None))) - val hudson_transfer$ = StructType("hudson_transfer$", List(FieldDesc("word", "a", None), FieldDesc("word", "b", None), FieldDesc("word", "c", None))) + val byte_and_pointer$ = StructType("byte_and_pointer$", List(FieldDesc("byte", "zp", None), FieldDesc("pointer", "branch", None)), NoAlignment) + val hudson_transfer$ = StructType("hudson_transfer$", List(FieldDesc("word", "a", None), FieldDesc("word", "b", None), FieldDesc("word", "c", None)), NoAlignment) addThing(byte_and_pointer$, None) addThing(hudson_transfer$, None) Environment.constOnlyBuiltinFunction.foreach(n => addThing(ConstOnlyCallable(n), None)) @@ -642,8 +642,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa hintTypo(name) 1 case Some(thing) => thing match { - case t: Type => t.size - case v: Variable => v.typ.size + case t: Type => t.alignedSize + case v: Variable => v.typ.alignedSize case a: MfArray => a.sizeInBytes case ConstantThing(_, MemoryAddressConstant(a: MfArray), _) => a.sizeInBytes case x => @@ -652,7 +652,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } case _ => - AbstractExpressionCompiler.getExpressionType(this, log, expr).size + AbstractExpressionCompiler.getExpressionType(this, log, expr).alignedSize } NumericConstant(size, Constant.minimumSize(size)) } @@ -675,7 +675,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } case IndexedExpression(arrName, index) => getPointy(arrName) match { - case ConstantPointy(MemoryAddressConstant(arr:InitializedArray), _, _, _, _, _, _, _) if arr.readOnly && arr.elementType.size == 1 => + case ConstantPointy(MemoryAddressConstant(arr:InitializedArray), _, _, _, _, _, _, _) if arr.readOnly && arr.elementType.alignedSize == 1 => evalImpl(index, vv).flatMap { case NumericConstant(constIndex, _) => if (constIndex >= 0 && constIndex < arr.sizeInBytes) { @@ -1062,7 +1062,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error(s"Invalid field name: `${f.fieldName}`", stmt.position) } } - addThing(StructType(stmt.name, stmt.fields), stmt.position) + addThing(StructType(stmt.name, stmt.fields, stmt.alignment.getOrElse(NoAlignment)), stmt.position) } def registerUnion(stmt: UnionDefinitionStatement): Unit = { @@ -1071,7 +1071,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error(s"Invalid field name: `${f.fieldName}`", stmt.position) } } - addThing(UnionType(stmt.name, stmt.fields), stmt.position) + addThing(UnionType(stmt.name, stmt.fields, stmt.alignment.getOrElse(NoAlignment)), stmt.position) } def getTypeSize(t: VariableType, path: Set[String]): Int = { @@ -1765,7 +1765,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } val length = contents.length if (length > 0xffff || length < 0) log.error(s"Array `${stmt.name}` has invalid length", stmt.position) - val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length)) + val alignment = stmt.alignment.getOrElse(defaultArrayAlignment(options, length)) & e.alignment val address = stmt.address.map(a => eval(a).getOrElse(errorConstant(s"Array `${stmt.name}` has non-constant address", stmt.position))) for (element <- contents) { AbstractExpressionCompiler.checkAssignmentTypeLoosely(this, element, e) @@ -1825,7 +1825,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val b = get[Type]("byte") val w = get[Type]("word") val typ = get[VariableType](stmt.typ) - val alignment = stmt.alignment.getOrElse(defaultVariableAlignment(options, typ.size)) + val alignment = stmt.alignment.getOrElse(defaultVariableAlignment(options, typ.size)) & typ.alignment if (stmt.constant) { if (stmt.stack) log.error(s"`$name` is a constant and cannot be on stack", position) if (stmt.register) log.error(s"`$name` is a constant and cannot be in a register", position) @@ -2161,8 +2161,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def fixStructSizes(): Unit = { val allStructTypes: Iterable[VariableType] = things.values.flatMap { - case s@StructType(name, _) => Some(s) - case s@UnionType(name, _) => Some(s) + case s@StructType(name, _, _) => Some(s) + case s@UnionType(name, _, _) => Some(s) case _ => None } var iterations = allStructTypes.size @@ -2180,7 +2180,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def fixStructFields(): Unit = { // TODO: handle arrays? things.values.foreach { - case st@StructType(_, fields) => + case st@StructType(_, fields, _) => st.mutableFieldsWithTypes = fields.map { case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map { x => eval(x) match { @@ -2191,7 +2191,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } }) } - case ut@UnionType(_, fields) => + case ut@UnionType(_, fields, _) => ut.mutableFieldsWithTypes = fields.map { case FieldDesc(tn, name, arraySize) => ResolvedFieldDesc(get[VariableType](tn), name, arraySize.map { x => eval(x) match { diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 1fe9caa6..896dc221 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -24,6 +24,10 @@ sealed trait Type extends CallableThing { def size: Int + def alignedSize: Int = alignment.roundSizeUp(size) + + def alignment: MemoryAlignment + def isSigned: Boolean def isBoollike: Boolean = false @@ -55,6 +59,8 @@ case object VoidType extends Type { def isSigned = false override def name = "void" + + override def alignment: MemoryAlignment = NoAlignment } sealed trait PlainType extends VariableType { @@ -73,12 +79,16 @@ case class BasicPlainType(name: String, size: Int) extends PlainType { def isSigned = false override def isSubtypeOf(other: Type): Boolean = this == other + + override def alignment: MemoryAlignment = NoAlignment } case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean, override val isPointy: Boolean) extends PlainType { def size: Int = parent.size override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other) + + override def alignment: MemoryAlignment = parent.alignment } case class PointerType(name: String, targetName: String, var target: Option[Type]) extends VariableType { @@ -89,6 +99,8 @@ case class PointerType(name: String, targetName: String, var target: Option[Type override def isPointy: Boolean = true override def pointerTargetName: String = targetName + + override def alignment: MemoryAlignment = NoAlignment } case class FunctionPointerType(name: String, paramTypeName:String, returnTypeName: String, var paramType: Option[Type], var returnType: Option[Type]) extends VariableType { @@ -97,6 +109,8 @@ case class FunctionPointerType(name: String, paramTypeName:String, returnTypeNam override def isSigned: Boolean = false override def isPointy: Boolean = false + + override def alignment: MemoryAlignment = NoAlignment } case object NullType extends VariableType { @@ -111,24 +125,28 @@ case object NullType extends VariableType { override def isSubtypeOf(other: Type): Boolean = this == other || (other.isPointy || other.isInstanceOf[FunctionPointerType]) && other.size == 2 override def isAssignableTo(targetType: Type): Boolean = this == targetType || (targetType.isPointy || targetType.isInstanceOf[FunctionPointerType]) && targetType.size == 2 + + override def alignment: MemoryAlignment = NoAlignment } case class EnumType(name: String, count: Option[Int]) extends VariableType { override def size: Int = 1 override def isSigned: Boolean = false + + override def alignment: MemoryAlignment = NoAlignment } sealed trait CompoundVariableType extends VariableType -case class StructType(name: String, fields: List[FieldDesc]) extends CompoundVariableType { +case class StructType(name: String, fields: List[FieldDesc], override val alignment: MemoryAlignment) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 var mutableFieldsWithTypes: List[ResolvedFieldDesc] = Nil override def isSigned: Boolean = false } -case class UnionType(name: String, fields: List[FieldDesc]) extends CompoundVariableType { +case class UnionType(name: String, fields: List[FieldDesc], override val alignment: MemoryAlignment) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 var mutableFieldsWithTypes: List[ResolvedFieldDesc] = Nil @@ -151,6 +169,8 @@ case object FatBooleanType extends VariableType { override def isAssignableTo(targetType: Type): Boolean = this == targetType override def isExplicitlyCastableTo(targetType: Type): Boolean = targetType.isArithmetic || isAssignableTo(targetType) + + override def alignment: MemoryAlignment = NoAlignment } sealed trait BooleanType extends Type { @@ -163,6 +183,8 @@ sealed trait BooleanType extends Type { override def isAssignableTo(targetType: Type): Boolean = isCompatible(targetType) || targetType == FatBooleanType override def isExplicitlyCastableTo(targetType: Type): Boolean = targetType.isArithmetic || isAssignableTo(targetType) + + override def alignment: MemoryAlignment = NoAlignment } case class ConstantBooleanType(name: String, value: Boolean) extends BooleanType @@ -252,7 +274,7 @@ case class Placeholder(name: String, typ: Type) extends Variable { } sealed trait UninitializedMemory extends ThingInMemory { - def sizeInBytes: Int + def sizeInBytes: Int def alloc: VariableAllocationMethod.Value @@ -284,7 +306,7 @@ case class UninitializedMemoryVariable( declaredBank: Option[String], override val alignment: MemoryAlignment, override val isVolatile: Boolean) extends MemoryVariable with UninitializedMemory { - override def sizeInBytes: Int = typ.size + override def sizeInBytes: Int = typ.alignedSize override def zeropage: Boolean = alloc == VariableAllocationMethod.Zeropage @@ -328,7 +350,7 @@ case class UninitializedArray(name: String, elementCount: Int, declaredBank: Opt override def zeropage: Boolean = false - override def sizeInBytes: Int = elementCount * elementType.size + override def sizeInBytes: Int = elementCount * elementType.alignedSize } case class RelativeArray(name: String, address: Constant, elementCount: Int, declaredBank: Option[String], indexType: VariableType, elementType: VariableType, override val readOnly: Boolean) extends MfArray { @@ -340,7 +362,7 @@ case class RelativeArray(name: String, address: Constant, elementCount: Int, dec override def zeropage: Boolean = false - override def sizeInBytes: Int = elementCount * elementType.size + override def sizeInBytes: Int = elementCount * elementType.alignedSize override def rootName: String = address.rootThingName } @@ -357,7 +379,7 @@ case class InitializedArray(name: String, address: Option[Constant], contents: S override def elementCount: Int = contents.size - override def sizeInBytes: Int = contents.size * elementType.size + override def sizeInBytes: Int = contents.size * elementType.alignedSize } case class RelativeVariable(name: String, address: Constant, typ: Type, zeropage: Boolean, declaredBank: Option[String], override val isVolatile: Boolean) extends VariableInMemory { diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 66dca333..c5900f0c 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -565,11 +565,11 @@ case class EnumDefinitionStatement(name: String, variants: List[(String, Option[ override def getAllExpressions: List[Expression] = variants.flatMap(_._2) } -case class StructDefinitionStatement(name: String, fields: List[FieldDesc]) extends DeclarationStatement { +case class StructDefinitionStatement(name: String, fields: List[FieldDesc], alignment: Option[MemoryAlignment]) extends DeclarationStatement { override def getAllExpressions: List[Expression] = Nil } -case class UnionDefinitionStatement(name: String, fields: List[FieldDesc]) extends DeclarationStatement { +case class UnionDefinitionStatement(name: String, fields: List[FieldDesc], alignment: Option[MemoryAlignment]) extends DeclarationStatement { override def getAllExpressions: List[Expression] = Nil } diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index 2bf7a4fb..50c4442a 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -359,7 +359,7 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program for (item <- items) { env.eval(item) match { case Some(c) => - for(i <- 0 until elementType.size) { + for(i <- 0 until elementType.alignedSize) { writeByte(bank, index, subbyte(c, i, elementType.size))(None) bank0.occupied(index) = true bank0.initialized(index) = true @@ -544,11 +544,11 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program log.error(s"Preinitialized variable `$name` should be defined in the `default` bank") } val bank0 = mem.banks(bank) - var index = codeAllocators(bank).allocateBytes(bank0, options, typ.size, initialized = true, writeable = true, location = AllocationLocation.High, alignment = alignment) + var index = codeAllocators(bank).allocateBytes(bank0, options, typ.alignedSize, initialized = true, writeable = true, location = AllocationLocation.High, alignment = alignment) labelMap(name) = bank0.index -> index if (!readOnlyPass) { rwDataStart = rwDataStart.min(index) - rwDataEnd = rwDataEnd.max(index + typ.size) + rwDataEnd = rwDataEnd.max(index + typ.alignedSize) } val altName = m.name.stripPrefix(env.prefix) + "`" env.things += altName -> ConstantThing(altName, NumericConstant(index, 2), env.get[Type]("pointer")) diff --git a/src/main/scala/millfork/output/MemoryAlignment.scala b/src/main/scala/millfork/output/MemoryAlignment.scala index d082fc89..9633d87d 100644 --- a/src/main/scala/millfork/output/MemoryAlignment.scala +++ b/src/main/scala/millfork/output/MemoryAlignment.scala @@ -1,20 +1,38 @@ package millfork.output +import millfork.MathUtils + /** * @author Karol Stasiak */ sealed trait MemoryAlignment { def isMultiplePages: Boolean + def roundSizeUp(size: Int): Int + def &(other: MemoryAlignment): MemoryAlignment } case object NoAlignment extends MemoryAlignment { override def isMultiplePages: Boolean = false + override def roundSizeUp(size: Int): Int = size + override def &(other: MemoryAlignment): MemoryAlignment = other } case object WithinPageAlignment extends MemoryAlignment { override def isMultiplePages: Boolean = false + override def roundSizeUp(size: Int): Int = size + override def &(other: MemoryAlignment): MemoryAlignment = other match { + case NoAlignment | WithinPageAlignment => this + case _ => throw new IllegalArgumentException(s"Cannot use incompatible alignments $this and $other simultaneously") + } } case class DivisibleAlignment(divisor: Int) extends MemoryAlignment { override def isMultiplePages: Boolean = divisor > 256 + override def roundSizeUp(size: Int): Int = + if (size % divisor == 0) size else size + divisor - size % divisor + override def &(other: MemoryAlignment): MemoryAlignment = other match { + case NoAlignment => this + case DivisibleAlignment(d) => DivisibleAlignment(MathUtils.lcm(divisor, d)) + case _ => throw new IllegalArgumentException(s"Cannot use incompatible alignments $this and $other simultaneously") + } } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index fcbe3987..cec3fe00 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -725,17 +725,19 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri p <- position() _ <- "struct" ~ !letterOrDigit ~/ SWS ~ position("struct name") name <- identifier ~/ HWS - _ <- position("struct defintion block") + align <- alignmentDeclaration(NoAlignment).? ~/ HWS + _ <- position("struct definition block") fields <- compoundTypeFields ~/ Pass - } yield Seq(StructDefinitionStatement(name, fields).pos(p)) + } yield Seq(StructDefinitionStatement(name, fields, align).pos(p)) val unionDefinition: P[Seq[UnionDefinitionStatement]] = for { p <- position() _ <- "union" ~ !letterOrDigit ~/ SWS ~ position("union name") name <- identifier ~/ HWS - _ <- position("union defintion block") + align <- alignmentDeclaration(NoAlignment).? ~/ HWS + _ <- position("union definition block") fields <- compoundTypeFields ~/ Pass - } yield Seq(UnionDefinitionStatement(name, fields).pos(p)) + } yield Seq(UnionDefinitionStatement(name, fields, align).pos(p)) val segmentBlock: P[Seq[BankedDeclarationStatement]] = for { (_, bankName) <- "segment" ~ AWS ~ "(" ~ AWS ~ position("segment name") ~ identifier ~ AWS ~ ")" ~ AWS ~ "{" ~/ AWS diff --git a/src/test/scala/millfork/test/ArraySuite.scala b/src/test/scala/millfork/test/ArraySuite.scala index 71d18890..9044e5a4 100644 --- a/src/test/scala/millfork/test/ArraySuite.scala +++ b/src/test/scala/millfork/test/ArraySuite.scala @@ -641,4 +641,60 @@ class ArraySuite extends FunSuite with Matchers with AppendedClues { | } """.stripMargin) } + + test("Arrays of aligned structs") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Intel8080, Cpu.Z80, Cpu.Motorola6809)( + """ + | struct s align(2) { byte x } + | array(s) a[7] + | word output @$c000 + | byte output2 @$c002 + | void main () { + | output = a[1].pointer - a[0].pointer + | a[f(3)].x = 5 + | output2 = a[3].x + | } + | noinline byte f(byte x) = x + """.stripMargin) { m => + m.readWord(0xc000) should equal(2) + m.readByte(0xc002) should equal(5) + } + } + + test("Accessing large fields of structs in arrays") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos/*, Cpu.Intel8080, Cpu.Z80, Cpu.Motorola6809*/)( + """ + | struct s { word x, word y} + | array(s) a[7] + | word output @$c000 + | + | void main () { + | a[f(4)].y = 5 + | a[f(4)].y = a[f(4)].y + | output = a[f(4)].y + | } + | noinline byte f(byte x) = x + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + } + } + + test("Accessing large fields of structs in arrays 2") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos/*, Cpu.Intel8080, Cpu.Z80, Cpu.Motorola6809*/)( + """ + | struct s { word x, word y} + | array(s) a[7] @$c000 + | + | void main () { + | byte t + | t = f(4) + | a[t+4].x = t << 3 + | a[t+4].y = t ^ 6 + | } + | noinline byte f(byte x) = x + """.stripMargin) { m => + m.readWord(0xc020) should equal(32) + m.readWord(0xc022) should equal(2) + } + } } diff --git a/src/test/scala/millfork/test/MathUtilsSuite.scala b/src/test/scala/millfork/test/MathUtilsSuite.scala new file mode 100644 index 00000000..e0219f45 --- /dev/null +++ b/src/test/scala/millfork/test/MathUtilsSuite.scala @@ -0,0 +1,23 @@ +package millfork.test + +import millfork.MathUtils +import millfork.MathUtils.gcd +import org.scalatest.{AppendedClues, FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class MathUtilsSuite extends FunSuite with Matchers with AppendedClues { + + test("GCD") { + gcd(2, 4) should equal(2) + gcd(4, 2) should equal(2) + gcd(5, 1) should equal(1) + gcd(5, 5) should equal(5) + gcd(0, 5) should equal(5) + gcd(5, 0) should equal(5) + gcd(9, 12) should equal(3) + gcd(12, 9) should equal(3) + } + +}