From 029e84b0f08e1e2ea2aa2a9ef399a199e6231c0d Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 15 Apr 2019 19:45:26 +0200 Subject: [PATCH] Unions, typed pointers, indirect field access via pointers --- CHANGELOG.md | 4 +- docs/lang/operators.md | 8 + docs/lang/types.md | 34 ++++ .../scala/millfork/CompilationOptions.scala | 1 + .../compiler/AbstractExpressionCompiler.scala | 23 +++ .../AbstractStatementPreprocessor.scala | 61 +++++- .../millfork/compiler/mos/BuiltIns.scala | 56 +++++- .../compiler/mos/MosExpressionCompiler.scala | 182 ++++++++++++++++++ .../compiler/z80/Z80ExpressionCompiler.scala | 90 +++++++++ .../millfork/compiler/z80/ZBuiltIns.scala | 4 + src/main/scala/millfork/env/Environment.scala | 126 ++++++++++-- src/main/scala/millfork/env/Thing.scala | 22 ++- src/main/scala/millfork/node/Node.scala | 52 ++++- .../node/opt/UnusedLocalVariables.scala | 3 + src/main/scala/millfork/parser/MfParser.scala | 50 ++++- src/test/scala/millfork/test/DerefSuite.scala | 86 +++++++++ .../scala/millfork/test/PointerSuite.scala | 115 ++++++++++- .../scala/millfork/test/StructSuite.scala | 20 +- src/test/scala/millfork/test/emu/EmuRun.scala | 1 + .../scala/millfork/test/emu/EmuZ80Run.scala | 1 + 20 files changed, 896 insertions(+), 43 deletions(-) create mode 100644 src/test/scala/millfork/test/DerefSuite.scala diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c132bd8..5c3ae4c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,12 @@ * Added `MILLFORK_VERSION` preprocessor parameter. -* Added structs. +* Added structs and unions. * Pointers can now be allocated anywhere. +* Pointers can now be typed. + * Arrays can now have elements of types other than `byte` (still limited in size to 1 byte though). * Added hint for identifiers with typos. diff --git a/docs/lang/operators.md b/docs/lang/operators.md index 6e39913e..6b99cc98 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -65,6 +65,12 @@ Such expressions have the property that the only register they may clobber is Y. Expressions of the shape `h:l` where `h` and `l` are of type byte, are considered expressions of type word. If and only if both `h` and `l` are assignable expressions, then `h:l` is also an assignable expression. +## Indirect field access operator + +`->` + +TODO + ## Binary arithmetic operators * `+`, `-`: @@ -133,12 +139,14 @@ Note you cannot mix those operators, so `a <= b < c` is not valid. `enum == enum` `byte == byte` `simple word == simple word` +`word == constant` `simple long == simple long` * `!=`: inequality `enum != enum` `byte != byte` `simple word != simple word` +`word != constant` `simple long != simple long` * `>`, `<`, `<=`, `>=`: inequality diff --git a/docs/lang/types.md b/docs/lang/types.md index 5a56929a..495208a2 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -50,6 +50,23 @@ Numeric types can be converted automatically: * from a type of defined signedness to a type of undefined signedness (`sbyte`→`byte`) +## Typed pointers + +For every type `T`, there is a pointer type defined called `pointer.T`. + +Unlike raw pointers, they are not subject to arithmetic. + +Examples: + + pointer.t p + p.raw // expression of type pointer, pointing to the same location in memory as 'p' + p.lo // equivalent to 'p.raw.lo' + p.hi // equivalent to 'p.raw.lo' + p[0] // valid only if the type 't' is of size 1 or 2, accesses the pointed element + p[i] // valid only if the type 't' is of size 1, equivalent to 't(p.raw[i])' + p->x // valid only if the type 't' has a field called 'x', accesses the field 'x' of the pointed element + p->x.y->z // you can stack it + ## Boolean types TODO @@ -115,3 +132,20 @@ Offsets are available as `structname.fieldname.offset`: // alternatively: ptr = p.y.addr + +## Unions + + union { } + +Unions are pretty similar to structs, with the difference that all fields of the union +start at the same point in memory and therefore overlap each other. + + struct point { byte x, byte y } + union point_or_word { point p, word w } + + point_or_word u + u.p.x = 0 + u.p.y = 0 + if u.w == 0 { ok() } + +Offset constants are also available, but they're obviously all zero. diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala index 9c87f2ee..1516e7fa 100644 --- a/src/main/scala/millfork/CompilationOptions.scala +++ b/src/main/scala/millfork/CompilationOptions.scala @@ -347,6 +347,7 @@ object CompilationFlag extends Enumeration { NonZeroTerminatedLiteralWarning, FatalWarnings, // special options for internal compiler use + EnableInternalTestSyntax, InternalCurrentlyOptimizingForMeasurement = Value val allWarnings: Set[CompilationFlag.Value] = Set(ExtraComparisonWarnings) diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 4e7ada03..c2b6261f 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -254,6 +254,26 @@ object AbstractExpressionCompiler { b case IndexedExpression(name, _) => env.getPointy(name).elementType + case DerefDebuggingExpression(_, 1) => b + case DerefDebuggingExpression(_, 2) => w + case DerefExpression(_, _, typ) => typ + case IndirectFieldExpression(inner, fieldPath) => + val firstPointerType = getExpressionType(env, log, inner) + fieldPath.foldLeft(firstPointerType) { (currentType, fieldName) => + currentType match { + case PointerType(_, _, Some(targetType)) => + val tuples = env.getSubvariables(targetType).filter(x => x._1 == "." + fieldName) + if (tuples.isEmpty) { + log.error(s"Type `$targetType` doesn't have field named `$fieldName`", expr.position) + b + } else { + tuples.head._3 + } + case _ => + log.error(s"Type `$currentType` is not a pointer type", expr.position) + b + } + } case SeparateBytesExpression(hi, lo) => if (getExpressionType(env, log, hi).size > 1) log.error("Hi byte too large", hi.position) if (getExpressionType(env, log, lo).size > 1) log.error("Lo byte too large", lo.position) @@ -271,6 +291,9 @@ object AbstractExpressionCompiler { } case FunctionCallExpression("hi", params) => b case FunctionCallExpression("lo", params) => b + case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionType(env, log, params(1)) + case FunctionCallExpression("cos", params) => if (params.size < 2) b else getExpressionType(env, log, params(1)) + case FunctionCallExpression("tan", params) => if (params.size < 2) b else getExpressionType(env, log, params(1)) case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match { case 1 => b case 2 => w diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index 1bbc2ac9..b45a1d70 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -57,6 +57,15 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements def maybeOptimizeForStatement(f: ForStatement): Option[(ExecutableStatement, VV)] + def isNonzero(index: Expression): Boolean = env.eval(index) match { + case Some(c) => !c.isProvablyZero + case _ => true + } + + def isWordPointy(name: String): Boolean = { + env.getPointy(name).elementType.size == 2 + } + def optimizeStmt(stmt: ExecutableStatement, currentVarValues: VV): (ExecutableStatement, VV) = { var cv = currentVarValues val pos = stmt.position @@ -105,6 +114,29 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements case Some(c) => cv + (v -> c) case None => cv - v }) + case Assignment(target:DerefDebuggingExpression, arg) => + cv = search(arg, cv) + cv = search(target, cv) + Assignment(optimizeExpr(target, cv).asInstanceOf[LhsExpression], optimizeExpr(arg, cv)).pos(pos) -> cv + case Assignment(target:DerefExpression, arg) => + cv = search(arg, cv) + cv = search(target, cv) + Assignment(optimizeExpr(target, cv).asInstanceOf[LhsExpression], optimizeExpr(arg, cv)).pos(pos) -> cv + case Assignment(target:IndirectFieldExpression, arg) => + cv = search(arg, cv) + cv = search(target, cv) + Assignment(optimizeExpr(target, cv).asInstanceOf[LhsExpression], optimizeExpr(arg, cv)).pos(pos) -> cv + case Assignment(target:IndexedExpression, arg) if isWordPointy(target.name) => + if (isNonzero(target.index)) { + ctx.log.error("Pointers to word variables can be only indexed by 0") + } + cv = search(arg, cv) + cv = search(target, cv) + Assignment(DerefExpression(VariableExpression(target.name).pos(pos), 0, env.getPointy(target.name).elementType).pos(pos), optimizeExpr(arg, cv)).pos(pos) -> cv + case Assignment(target:IndexedExpression, arg) => + cv = search(arg, cv) + cv = search(target, cv) + Assignment(optimizeExpr(target, cv).asInstanceOf[LhsExpression], optimizeExpr(arg, cv)).pos(pos) -> cv case Assignment(ve, arg) => cv = search(arg, cv) cv = search(ve, cv) @@ -180,6 +212,8 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements case SumExpression(params, _) => params.map(p => search(p._2, cv)).reduce(commonVV) case HalfWordExpression(arg, _) => search(arg, cv) case IndexedExpression(_, arg) => search(arg, cv) + case DerefDebuggingExpression(arg, _) => search(arg, cv) + case DerefExpression(arg, _, _) => search(arg, cv) case _ => cv // TODO } } @@ -243,10 +277,29 @@ abstract class AbstractStatementPreprocessor(ctx: CompilationContext, statements case _ => } expr match { - case FunctionCallExpression("->", List(handle, VariableExpression(field))) => - expr - case FunctionCallExpression("->", List(handle, FunctionCallExpression(method, params))) => - expr + case IndirectFieldExpression(root, fieldPath) if AbstractExpressionCompiler.getExpressionType(env, env.log, root).isInstanceOf[PointerType] => + fieldPath.foldLeft(root) { (pointer, fieldName) => + AbstractExpressionCompiler.getExpressionType(env, env.log, pointer) match { + case PointerType(_, _, Some(target)) => + val subvariables = env.getSubvariables(target).filter(x => x._1 == "." + fieldName) + if (subvariables.isEmpty) { + ctx.log.error(s"Type `${target.name}` does not contain field `$fieldName`", pointer.position) + LiteralExpression(0, 1) + } else { + DerefExpression(optimizeExpr(pointer, currentVarValues).pos(pos), subvariables.head._2, subvariables.head._3) + } + case _ => + ctx.log.error("Invalid pointer type on the left-hand side of `->`", pointer.position) + LiteralExpression(0, 1) + } + } + case IndirectFieldExpression(root, fieldPath) => + ctx.log.error("Invalid pointer type on the left-hand side of `->`", pos) + root + case DerefDebuggingExpression(inner, 1) => + DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("byte")).pos(pos) + case DerefDebuggingExpression(inner, 2) => + DerefExpression(optimizeExpr(inner, currentVarValues), 0, env.get[VariableType]("word")).pos(pos) case TextLiteralExpression(characters) => val name = genName(characters) if (ctx.env.maybeGet[Thing](name).isEmpty) { diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 68047505..58cfa3bd 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -1,6 +1,6 @@ package millfork.compiler.mos -import millfork.CompilationFlag +import millfork.{CompilationFlag, assembly} import millfork.assembly.Elidability import millfork.assembly.mos.AddrMode._ import millfork.assembly.mos.Opcode._ @@ -327,6 +327,10 @@ object BuiltIns { } def compileInPlaceWordOrLongShiftOps(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, aslRatherThanLsr: Boolean): List[AssemblyLine] = { + if (lhs.isInstanceOf[DerefExpression]) { + ctx.log.error("Too complex left-hand-side expression") + return MosExpressionCompiler.compileToAX(ctx, lhs) ++ MosExpressionCompiler.compileToAX(ctx, rhs) + } val env = ctx.env val b = env.get[Type]("byte") val targetBytes = getStorageForEachByte(ctx, lhs) @@ -579,7 +583,7 @@ object BuiltIns { case BranchIfTrue(label) => compType -> label case BranchIfFalse(label) => ComparisonType.negate(compType) -> label } - val (lh, ll, rh, rl) = (lhs, env.eval(lhs), rhs, env.eval(rhs)) match { + val (preparations, lh, ll, rh, rl) = (lhs, env.eval(lhs), rhs, env.eval(rhs)) match { case (_, Some(NumericConstant(lc, _)), _, Some(NumericConstant(rc, _))) => return if (effectiveComparisonType match { // TODO: those masks are probably wrong @@ -613,21 +617,51 @@ object BuiltIns { return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches) case (v: VariableExpression, None, _, Some(rc)) => val lva = env.get[VariableInMemory](v.name) - (AssemblyLine.variable(ctx, CMP, lva, 1), + (Nil, + AssemblyLine.variable(ctx, CMP, lva, 1), AssemblyLine.variable(ctx, CMP, lva, 0), List(AssemblyLine.immediate(CMP, rc.hiByte.quickSimplify)), List(AssemblyLine.immediate(CMP, rc.loByte.quickSimplify))) case (lv: VariableExpression, None, rv: VariableExpression, None) => val lva = env.get[VariableInMemory](lv.name) val rva = env.get[VariableInMemory](rv.name) - (AssemblyLine.variable(ctx, CMP, lva, 1), + (Nil, + AssemblyLine.variable(ctx, CMP, lva, 1), AssemblyLine.variable(ctx, CMP, lva, 0), AssemblyLine.variable(ctx, CMP, rva, 1), AssemblyLine.variable(ctx, CMP, rva, 0)) + case (expr, None, _, Some(constant)) if effectiveComparisonType == ComparisonType.Equal => + val innerLabel = ctx.nextLabel("cp") + return MosExpressionCompiler.compileToAX(ctx, expr) ++ List( + AssemblyLine.immediate(CMP, constant.loByte), + AssemblyLine.relative(BNE, innerLabel), + AssemblyLine.immediate(CPX, constant.hiByte), + AssemblyLine.relative(BEQ, Label(x)), + AssemblyLine.label(innerLabel)) + case (_, Some(constant), expr, None) if effectiveComparisonType == ComparisonType.Equal => + val innerLabel = ctx.nextLabel("cp") + return MosExpressionCompiler.compileToAX(ctx, expr) ++ List( + AssemblyLine.immediate(CMP, constant.loByte), + AssemblyLine.relative(BNE, innerLabel), + AssemblyLine.immediate(CPX, constant.hiByte), + AssemblyLine.relative(BEQ, Label(x)), + AssemblyLine.label(innerLabel)) + case (expr, None, _, Some(constant)) if effectiveComparisonType == ComparisonType.NotEqual => + return MosExpressionCompiler.compileToAX(ctx, expr) ++ List( + AssemblyLine.immediate(CMP, constant.loByte), + AssemblyLine.relative(BNE, Label(x)), + AssemblyLine.immediate(CPX, constant.hiByte), + AssemblyLine.relative(BNE, Label(x))) + case (_, Some(constant), expr, None) if effectiveComparisonType == ComparisonType.NotEqual => + return MosExpressionCompiler.compileToAX(ctx, expr) ++ List( + AssemblyLine.immediate(CMP, constant.loByte), + AssemblyLine.relative(BNE, Label(x)), + AssemblyLine.immediate(CPX, constant.hiByte), + AssemblyLine.relative(BNE, Label(x))) case _ => // TODO comparing expressions ctx.log.error("Too complex expressions in comparison", lhs.position.orElse(rhs.position)) - (Nil, Nil, Nil, Nil) + (Nil, Nil, Nil, Nil, Nil) } val lType = MosExpressionCompiler.getExpressionType(ctx, lhs) val rType = MosExpressionCompiler.getExpressionType(ctx, rhs) @@ -865,6 +899,10 @@ object BuiltIns { private def isPowerOfTwoUpTo15(n: Long): Boolean = if (n <= 0 || n >= 0x8000) false else 0 == ((n-1) & n) def compileInPlaceWordMultiplication(ctx: CompilationContext, v: LhsExpression, addend: Expression): List[AssemblyLine] = { + if (v.isInstanceOf[DerefExpression]) { + ctx.log.error("Too complex left-hand-side expression") + return MosExpressionCompiler.compileToAX(ctx, v) ++ MosExpressionCompiler.compileToAX(ctx, addend) + } val b = ctx.env.get[Type]("byte") val w = ctx.env.get[Type]("word") ctx.env.eval(addend) match { @@ -1002,6 +1040,10 @@ object BuiltIns { ctx.log.error("Unsupported decimal operation. Consider increasing the size of the zeropage register.", lhs.position) return compileInPlaceWordOrLongAddition(ctx, lhs, addend, subtract, decimal = false) } + if (lhs.isInstanceOf[DerefExpression]) { + ctx.log.error("Too complex left-hand-side expression") + return MosExpressionCompiler.compileToAX(ctx, lhs) ++ MosExpressionCompiler.compileToAX(ctx, addend) + } val env = ctx.env val b = env.get[Type]("byte") val w = env.get[Type]("word") @@ -1327,6 +1369,10 @@ object BuiltIns { def compileInPlaceWordOrLongBitOp(ctx: CompilationContext, lhs: LhsExpression, param: Expression, operation: Opcode.Value): List[AssemblyLine] = { + if (lhs.isInstanceOf[DerefExpression]) { + ctx.log.error("Too complex left-hand-side expression") + return MosExpressionCompiler.compileToAX(ctx, lhs) ++ MosExpressionCompiler.compileToAX(ctx, param) + } val env = ctx.env val b = env.get[Type]("byte") val w = env.get[Type]("word") diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 7024d326..75d9d8a8 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -396,6 +396,15 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { ctx.log.error("Invalid index for writing", indexExpr.position) Nil } + case DerefExpression(inner, offset, targetType) => + val (prepare, reg) = getPhysicalPointerForDeref(ctx, inner) + val lo = preserveRegisterIfNeeded(ctx, MosRegister.A, prepare) ++ List(AssemblyLine.immediate(LDY, offset), AssemblyLine.indexedY(STA, reg)) + if (targetType.size == 1) { + lo + } else { + lo ++ List(AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(targetType.size - 1)(i => List(AssemblyLine.implied(INY), AssemblyLine.indexedY(STA, reg))).flatten + } } } val noop: List[AssemblyLine] = Nil @@ -407,6 +416,28 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { compile(ctx, expr, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None) } + def compileToAX(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = { + val env = ctx.env + val w = env.get[Type]("word") + compile(ctx, expr, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None) + } + + def compileToZReg(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = { + val env = ctx.env + val p = env.get[Type]("pointer") + compile(ctx, expr, Some(p -> env.get[Variable]("__reg.loword")), BranchSpec.None) + } + + def getPhysicalPointerForDeref(ctx: CompilationContext, pointerExpression: Expression): (List[AssemblyLine], ThingInMemory) = { + pointerExpression match { + case VariableExpression(name) => + val p = ctx.env.get[ThingInMemory](name) + if (p.zeropage) return Nil -> p + case _ => + } + compileToZReg(ctx, pointerExpression) -> ctx.env.get[ThingInMemory]("__reg.loword") + } + def compile(ctx: CompilationContext, expr: Expression, exprTypeAndVariable: Option[(Type, Variable)], branches: BranchSpec): List[AssemblyLine] = { val env = ctx.env env.eval(expr) match { @@ -815,6 +846,24 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case MosRegister.AY => result :+ AssemblyLine.immediate(LDY, 0) } } + case DerefExpression(inner, offset, targetType) => + val (prepare, reg) = getPhysicalPointerForDeref(ctx, inner) + targetType.size match { + case 1 => + prepare ++ List(AssemblyLine.immediate(LDY, offset), AssemblyLine.indexedY(LDA, reg)) ++ expressionStorageFromA(ctx, exprTypeAndVariable, expr.position) + case 2 => + prepare ++ + List( + AssemblyLine.immediate(LDY, offset+1), + AssemblyLine.indexedY(LDA, reg), + AssemblyLine.implied(TAX), + AssemblyLine.implied(DEY), + AssemblyLine.indexedY(LDA, reg)) ++ + expressionStorageFromAX(ctx, exprTypeAndVariable, expr.position) + case _ => + ctx.log.error("Cannot read a large object indirectly") + Nil + } case SumExpression(params, decimal) => assertAllArithmetic(ctx, params.map(_._2)) val a = params.map{case (n, p) => env.eval(p).map(n -> _)} @@ -1508,6 +1557,75 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } } + def expressionStorageFromA(ctx: CompilationContext, exprTypeAndVariable: Option[(Type, Variable)], position: Option[Position]): List[AssemblyLine] = { + exprTypeAndVariable.fold(noop) { + case (VoidType, _) => ctx.log.fatal("Cannot assign word to void", position) + case (_, RegisterVariable(MosRegister.A, _)) => noop + case (typ, RegisterVariable(MosRegister.AW, _)) => + if (typ.isSigned) List(AssemblyLine.implied(TAX)) ++ signExtendA(ctx) ++ List(AssemblyLine.implied(XBA), AssemblyLine.implied(TXA)) + else List(AssemblyLine.implied(XBA), AssemblyLine.immediate(LDA, 0), AssemblyLine.implied(XBA)) + case (_, RegisterVariable(MosRegister.X, _)) => List(AssemblyLine.implied(TAX)) + case (_, RegisterVariable(MosRegister.Y, _)) => List(AssemblyLine.implied(TAY)) + case (typ, RegisterVariable(MosRegister.AX, _)) => + if (typ.isSigned) { + if (ctx.options.flag(CompilationFlag.EmitHudsonOpcodes)) { + List(AssemblyLine.implied(TAX)) ++ signExtendA(ctx) ++ List(AssemblyLine.implied(HuSAX)) + } else { + List(AssemblyLine.implied(PHA)) ++ signExtendA(ctx) ++ List(AssemblyLine.implied(TAX), AssemblyLine.implied(PLA)) + } + } else List(AssemblyLine.immediate(LDX, 0)) + case (typ, RegisterVariable(MosRegister.XA, _)) => + if (typ.isSigned) { + List(AssemblyLine.implied(TAX)) ++ signExtendA(ctx) + } else { + List(AssemblyLine.implied(TAX), AssemblyLine.immediate(LDA, 0)) + } + case (typ, RegisterVariable(MosRegister.YA, _)) => + if (typ.isSigned) { + List(AssemblyLine.implied(TAY)) ++ signExtendA(ctx) + } else { + List(AssemblyLine.implied(TAY), AssemblyLine.immediate(LDA, 0)) + } + case (typ, RegisterVariable(MosRegister.AY, _)) => + if (typ.isSigned) { + if (ctx.options.flag(CompilationFlag.EmitHudsonOpcodes)) { + List(AssemblyLine.implied(TAY)) ++ signExtendA(ctx) ++ List(AssemblyLine.implied(SAY)) + } else { + List(AssemblyLine.implied(PHA)) ++ signExtendA(ctx) ++ List(AssemblyLine.implied(TAY), AssemblyLine.implied(PLA)) + } + } else List(AssemblyLine.immediate(LDY, 0)) + case (t, v: VariableInMemory) => + v.typ.size match { + case 1 => + AssemblyLine.variable(ctx, STA, v) + case s if s > 1 => + if (t.isSigned) { + AssemblyLine.variable(ctx, STA, v) ++ signExtendA(ctx) ++ List.tabulate(s - 1)(i => AssemblyLine.variable(ctx, STA, v, i + 1)).flatten + } else { + AssemblyLine.variable(ctx, STA, v) ++ List(AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(s - 1)(i => AssemblyLine.variable(ctx, STA, v, i + 1)).flatten + } + } + case (t, v: StackVariable) => + v.typ.size match { + case 1 => + AssemblyLine.tsx(ctx) :+ AssemblyLine.dataStackX(ctx, STA, v) + case s if s > 1 => + AssemblyLine.tsx(ctx) ++ (if (t.isSigned) { + List( + AssemblyLine.dataStackX(ctx, STA, v.baseOffset)) ++ + signExtendA(ctx) ++ + List.tabulate(s - 1)(i => AssemblyLine.dataStackX(ctx, STA, v, i + 1)) + } else { + List( + AssemblyLine.dataStackX(ctx, STA, v.baseOffset), + AssemblyLine.immediate(LDA, 0)) ++ + List.tabulate(s - 1)(i => AssemblyLine.dataStackX(ctx, STA, v, i + 1)) + }) + } + } + } + def expressionStorageFromAW(ctx: CompilationContext, exprTypeAndVariable: Option[(Type, Variable)], position: Option[Position]): List[AssemblyLine] = { exprTypeAndVariable.fold(noop) { case (VoidType, _) => ctx.log.fatal("Cannot assign word to void", position) @@ -1603,6 +1721,70 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case SeparateBytesExpression(_, _) => ctx.log.error("Invalid left-hand-side use of `:`") Nil + case DerefExpression(inner, offset, targetType) => + val (prepare, reg) = getPhysicalPointerForDeref(ctx, inner) + env.eval(source) match { + case Some(constant) => + targetType.size match { + case 1 => + prepare ++ List( + AssemblyLine.immediate(LDY, offset), + AssemblyLine.immediate(LDA, constant), + AssemblyLine.indexedY(STA, reg)) + case 2 => + prepare ++ List( + AssemblyLine.immediate(LDY, offset), + AssemblyLine.immediate(LDA, constant.loByte), + AssemblyLine.indexedY(STA, reg), + AssemblyLine.implied(INY), + AssemblyLine.immediate(LDA, constant.hiByte), + AssemblyLine.indexedY(STA, reg)) + } + case None => + source match { + case VariableExpression(vname) => + val variable = env.get[Variable](vname) + targetType.size match { + case 1 => + prepare ++ + AssemblyLine.variable(ctx, LDA, variable) ++ List( + AssemblyLine.immediate(LDY, offset), + AssemblyLine.indexedY(STA, reg)) + case 2 => + prepare ++ + AssemblyLine.variable(ctx, LDA, variable) ++ List( + AssemblyLine.immediate(LDY, offset), + AssemblyLine.indexedY(STA, reg)) ++ + AssemblyLine.variable(ctx, LDA, variable, 1) ++ List( + AssemblyLine.implied(INY), + AssemblyLine.indexedY(STA, reg)) + case _ => + ctx.log.error("Cannot assign to a large object indirectly") + Nil + } + case _ => + targetType.size match { + case 1 => + compile(ctx, source, Some(targetType, RegisterVariable(MosRegister.A, targetType)), BranchSpec.None) ++ compileByteStorage(ctx, MosRegister.A, target) + case 2 => + val someTuple = Some(targetType, RegisterVariable(MosRegister.AX, targetType)) + // TODO: optimiza if prepare is empty + compile(ctx, source, someTuple, BranchSpec.None) ++ List( + AssemblyLine.implied(PHA), + AssemblyLine.implied(TXA), + AssemblyLine.implied(PHA)) ++ prepare ++ List( + AssemblyLine.immediate(LDY, offset+1), + AssemblyLine.implied(PLA), + AssemblyLine.indexedY(STA, reg), + AssemblyLine.implied(PLA), + AssemblyLine.implied(DEY), + AssemblyLine.indexedY(STA, reg)) + case _ => + ctx.log.error("Cannot assign to a large object indirectly") + Nil + } + } + } case _ => compile(ctx, source, Some(b, RegisterVariable(MosRegister.A, b)), NoBranching) ++ compileByteStorage(ctx, MosRegister.A, target) } diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index 2aebd34c..e8917eb2 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -43,6 +43,14 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { def compileToHL(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.HL) + def compileDerefPointer(ctx: CompilationContext, expression: DerefExpression): List[ZLine] = { + compileToHL(ctx, expression.inner) ++ (expression.offset match { + case 0 => Nil + case i if i < 5 => List.fill(i)(ZLine.register(INC_16, ZRegister.HL)) // TODO: a better threshold + case _ => List(ZLine.ldImm8(ZRegister.C, expression.offset), ZLine.ldImm8(ZRegister.B, 0), ZLine.registers(ADD_16, ZRegister.HL, ZRegister.BC)) + }) + } + def compileToEHL(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.EHL) def compileToDEHL(ctx: CompilationContext, expression: Expression): List[ZLine] = compile(ctx, expression, ZExpressionTarget.DEHL) @@ -506,6 +514,69 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case ZExpressionTarget.DE => xxx(ZRegister.D, ZRegister.E, allowRedirect = true) case ZExpressionTarget.BC => xxx(ZRegister.B, ZRegister.C, allowRedirect = true) } + case e@DerefExpression(inner, offset, targetType) => + compileDerefPointer(ctx, e) ++ (targetType.size match { + case 1 => target match { + case ZExpressionTarget.A => List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL)) + case ZExpressionTarget.DEHL => + if (targetType.isSigned) { + List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), ZLine.ld8(ZRegister.L, ZRegister.A)) ++ + signExtendHighestByte(ctx, ZRegister.A) ++ + List(ZLine.ld8(ZRegister.H, ZRegister.A), ZLine.ld8(ZRegister.E, ZRegister.A), ZLine.ld8(ZRegister.D, ZRegister.A)) + } else { + List(ZLine.ld8(ZRegister.L, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.H, 0), ZLine.ldImm8(ZRegister.E, 0), ZLine.ldImm8(ZRegister.D, 0)) + } + case ZExpressionTarget.HL => + if (targetType.isSigned) { + List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), ZLine.ld8(ZRegister.L, ZRegister.A)) ++ signExtendHighestByte(ctx, ZRegister.H) + } else { + List(ZLine.ld8(ZRegister.L, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.H, 0)) + } + case ZExpressionTarget.BC => + if (targetType.isSigned) { + List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), ZLine.ld8(ZRegister.C, ZRegister.A)) ++ signExtendHighestByte(ctx, ZRegister.B) + } else { + List(ZLine.ld8(ZRegister.C, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.B, 0)) + } + case ZExpressionTarget.DE => + if (targetType.isSigned) { + List(ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), ZLine.ld8(ZRegister.E, ZRegister.A)) ++ signExtendHighestByte(ctx, ZRegister.D) + } else { + List(ZLine.ld8(ZRegister.E, ZRegister.MEM_HL), ZLine.ldImm8(ZRegister.D, 0)) + } + case ZExpressionTarget.NOTHING => Nil + } + case 2 => target match { + case ZExpressionTarget.DEHL => + List( + ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.H, ZRegister.MEM_HL), + ZLine.ld8(ZRegister.L, ZRegister.A), + ZLine.ldImm8(ZRegister.E, 0), + ZLine.ldImm8(ZRegister.D, 0)) // TODO + case ZExpressionTarget.HL => + List( + ZLine.ld8(ZRegister.A, ZRegister.MEM_HL), + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.H, ZRegister.MEM_HL), + ZLine.ld8(ZRegister.L, ZRegister.A)) + case ZExpressionTarget.BC => + List( + ZLine.ld8(ZRegister.C, ZRegister.MEM_HL), + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.B, ZRegister.MEM_HL)) + case ZExpressionTarget.DE => + List( + ZLine.ld8(ZRegister.E, ZRegister.MEM_HL), + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.D, ZRegister.MEM_HL)) + case ZExpressionTarget.NOTHING => Nil + } + case _ => + ctx.log.error("Cannot read a large object indirectly") + Nil + }) case f@FunctionCallExpression(name, params) => name match { case "not" => @@ -965,6 +1036,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } } case i:IndexedExpression => Some(LocalVariableAddressViaHL -> calculateAddressToHL(ctx, i)) + case i:DerefExpression => Some(LocalVariableAddressViaHL -> compileDerefPointer(ctx, i)) case _:SeparateBytesExpression => None case _ => ??? } @@ -1258,6 +1330,18 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ stashDEIfChanged(ctx, code) :+ ZLine.ld8(ZRegister.MEM_HL, ZRegister.E) } else code :+ ZLine.ld8(ZRegister.MEM_HL, ZRegister.A) } + case e@DerefExpression(_, _, targetType) => + val lo = stashAFIfChanged(ctx, compileDerefPointer(ctx, e)) :+ ZLine.ld8(ZRegister.MEM_HL, ZRegister.A) + if (targetType.size == 1) lo + else if (targetType.size == 2) { + lo ++ List(ZLine.register(INC_16, ZRegister.HL)) ++ signExtendHighestByte(ctx, ZRegister.MEM_HL) + lo + } else { + lo ++ signExtendHighestByte(ctx, ZRegister.A) ++ List.tabulate(targetType.size - 1)(_ => List( + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.MEM_HL, ZRegister.A) + )).flatten + } //TODO case SeparateBytesExpression(hi, lo) => ??? } @@ -1308,6 +1392,12 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case SeparateBytesExpression(hi: LhsExpression, lo: LhsExpression) => Z80ExpressionCompiler.stashHLIfChanged(ctx, ZLine.ld8(ZRegister.A, ZRegister.L) :: storeA(ctx, lo, signedSource)) ++ (ZLine.ld8(ZRegister.A, ZRegister.H) :: storeA(ctx, hi, signedSource)) + case e:DerefExpression => + List(ZLine.register(PUSH, ZRegister.HL)) ++ compileDerefPointer(ctx, e) ++ List( + ZLine.register(POP, ZRegister.BC), + ZLine.ld8(ZRegister.MEM_HL, ZRegister.C), + ZLine.register(INC_16, ZRegister.HL), + ZLine.ld8(ZRegister.MEM_HL, ZRegister.B)) case _: SeparateBytesExpression => ctx.log.error("Invalid `:`", target.position) Nil diff --git a/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala b/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala index 08faa034..d12d92d1 100644 --- a/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala +++ b/src/main/scala/millfork/compiler/z80/ZBuiltIns.scala @@ -450,6 +450,10 @@ object ZBuiltIns { def performLongInPlace(ctx: CompilationContext, lhs: LhsExpression, rhs: Expression, opcodeFirst: ZOpcode.Value, opcodeLater: ZOpcode.Value, size: Int, decimal: Boolean = false): List[ZLine] = { + if (lhs.isInstanceOf[DerefExpression]) { + ctx.log.error("Too complex left-hand-side expression") + return Z80ExpressionCompiler.compileToHL(ctx, lhs) ++ Z80ExpressionCompiler.compileToHL(ctx, rhs) + } if (size == 2 && !decimal) { // n × INC HL // 6n cycles, n bytes diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 916c5174..755c5621 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -226,6 +226,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa removedThings += str + ".addr" removedThings += str + ".addr.lo" removedThings += str + ".addr.hi" + removedThings += str + ".pointer" + removedThings += str + ".pointer.lo" + removedThings += str + ".pointer.hi" removedThings += str + ".rawaddr" removedThings += str + ".rawaddr.lo" removedThings += str + ".rawaddr.hi" @@ -233,6 +236,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa things -= str + ".addr" things -= str + ".addr.lo" things -= str + ".addr.hi" + things -= str + ".pointer" + things -= str + ".pointer.lo" + things -= str + ".pointer.hi" things -= str + ".rawaddr" things -= str + ".rawaddr.lo" things -= str + ".rawaddr.hi" @@ -240,6 +246,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa things -= str.stripPrefix(prefix) + ".addr" things -= str.stripPrefix(prefix) + ".addr.lo" things -= str.stripPrefix(prefix) + ".addr.hi" + things -= str.stripPrefix(prefix) + ".pointer" + things -= str.stripPrefix(prefix) + ".pointer.lo" + things -= str.stripPrefix(prefix) + ".pointer.hi" things -= str.stripPrefix(prefix) + ".rawaddr" things -= str.stripPrefix(prefix) + ".rawaddr.lo" things -= str.stripPrefix(prefix) + ".rawaddr.hi" @@ -247,6 +256,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = { + if (name.startsWith("pointer.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) { + val targetName = name.stripPrefix("pointer.") + val target = maybeGet[VariableType](targetName) + return PointerType(name, targetName, target).asInstanceOf[T] + } val clazz = implicitly[Manifest[T]].runtimeClass if (things.contains(name)) { val t: Thing = things(name) @@ -273,6 +287,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def root: Environment = parent.fold(this)(_.root) def maybeGet[T <: Thing : Manifest](name: String): Option[T] = { + if (name.startsWith("pointer.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) { + val targetName = name.stripPrefix("pointer.") + val target = maybeGet[VariableType](targetName) + return Some(PointerType(name, targetName, target)).asInstanceOf[Option[T]] + } if (things.contains(name)) { val t: Thing = things(name) val clazz = implicitly[Manifest[T]].runtimeClass @@ -310,17 +329,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case th@UninitializedArray(_, size, _, i, e, _) => ConstantPointy(th.toAddress, Some(name), Some(size), i, e, th.alignment) case th@RelativeArray(_, _, size, _, i, e) => ConstantPointy(th.toAddress, Some(name), Some(size), i, e, NoAlignment) case ConstantThing(_, value, typ) if typ.size <= 2 && typ.isPointy => - val b = get[VariableType]("byte") + val e = get[VariableType](typ.pointerTargetName) val w = get[VariableType]("word") - ConstantPointy(value, None, None, w, b, NoAlignment) + ConstantPointy(value, None, None, w, e, NoAlignment) case th:VariableInMemory if th.typ.isPointy=> - val b = get[VariableType]("byte") + val e = get[VariableType](th.typ.pointerTargetName) val w = get[VariableType]("word") - VariablePointy(th.toAddress, w, b, th.zeropage) + VariablePointy(th.toAddress, w, e, th.zeropage) case th:StackVariable if th.typ.isPointy => - val b = get[VariableType]("byte") + val e = get[VariableType](th.typ.pointerTargetName) val w = get[VariableType]("word") - StackVariablePointy(th.baseOffset, w, b) + StackVariablePointy(th.baseOffset, w, e) case _ => log.error(s"$name is not a valid pointer or array") val b = get[VariableType]("byte") @@ -517,6 +536,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case _ => maybeGet[ConstantThing](name).map(_.value) } case IndexedExpression(_, _) => None + case _: DerefExpression => None + case _: IndirectFieldExpression => None + case _: DerefDebuggingExpression => None case HalfWordExpression(param, hi) => evalImpl(e, vv).map(c => if (hi) c.hiByte else c.loByte) case SumExpression(params, decimal) => params.map { @@ -759,6 +781,15 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(StructType(stmt.name, stmt.fields), stmt.position) } + def registerUnion(stmt: UnionDefinitionStatement): Unit = { + stmt.fields.foreach{ f => + if (Environment.invalidFieldNames.contains(f._2)) { + log.error(s"Invalid field name: `${f._2}`", stmt.position) + } + } + addThing(UnionType(stmt.name, stmt.fields), stmt.position) + } + def getTypeSize(name: String, path: Set[String]): Int = { if (path.contains(name)) return -1 val t = get[Type](name) @@ -785,6 +816,26 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } sum } + case s: UnionType => + if (s.mutableSize >= 0) s.mutableSize + else { + val newPath = path + name + var max = 0 + for( (fieldType, _) <- s.fields) { + val fieldSize = getTypeSize(fieldType, newPath) + if (fieldSize < 0) return -1 + max = max max fieldSize + } + s.mutableSize = max + if (max > 0xff) { + log.error(s"Union `$name` is larger than 255 bytes") + } + val b = get[Type]("byte") + for ((fieldType, fieldName) <- s.fields) { + addThing(ConstantThing(s"$name.$fieldName.offset", NumericConstant(0, 1), b), None) + } + max + } case _ => t.size } } @@ -873,7 +924,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa stmt.bank ) addThing(mangled, stmt.position) - registerAddressConstant(mangled, stmt.position, options) + registerAddressConstant(mangled, stmt.position, options, None) addThing(ConstantThing(name + '`', addr, w), stmt.position) } @@ -939,12 +990,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa alignment = stmt.alignment.getOrElse(if (name == "main") NoAlignment else defaultFunctionAlignment(options, hot = true)) // TODO: decide actual hotness in a smarter way ) addThing(mangled, stmt.position) - registerAddressConstant(mangled, stmt.position, options) + registerAddressConstant(mangled, stmt.position, options, None) } } } - private def registerAddressConstant(thing: ThingInMemory, position: Option[Position], options: CompilationOptions): Unit = { + private def registerAddressConstant(thing: ThingInMemory, position: Option[Position], options: CompilationOptions, targetType: Option[Type]): Unit = { if (!thing.zeropage && options.flag(CompilationFlag.LUnixRelocatableCode)) { val b = get[Type]("byte") val w = get[Type]("word") @@ -953,6 +1004,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(relocatable, position) addThing(RelativeVariable(thing.name + ".addr.hi", addr + 1, b, zeropage = false, None, isVolatile = false), position) addThing(RelativeVariable(thing.name + ".addr.lo", addr, b, zeropage = false, None, isVolatile = false), position) + targetType.foreach {tt => + val typedPointer = RelativeVariable(thing.name + ".pointer", addr, PointerType("pointer."+tt.name, tt.name, Some(tt)), zeropage = false, None, isVolatile = false) + addThing(typedPointer, position) + addThing(RelativeVariable(thing.name + ".pointer.hi", addr + 1, b, zeropage = false, None, isVolatile = false), position) + addThing(RelativeVariable(thing.name + ".pointer.lo", addr, b, zeropage = false, None, isVolatile = false), position) + } val rawaddr = thing.toAddress addThing(ConstantThing(thing.name + ".rawaddr", rawaddr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".rawaddr.hi", rawaddr.hiByte, get[Type]("byte")), position) @@ -965,6 +1022,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(ConstantThing(thing.name + ".rawaddr", addr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".rawaddr.hi", addr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.lo", addr.loByte, get[Type]("byte")), position) + targetType.foreach { tt => + val pointerType = PointerType("pointer." + tt.name, tt.name, Some(tt)) + val typedPointer = RelativeVariable(thing.name + ".pointer", addr, pointerType, zeropage = false, None, isVolatile = false) + addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position) + addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, get[Type]("byte")), position) + addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, get[Type]("byte")), position) + } } } @@ -975,15 +1039,15 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val p = get[Type]("pointer") stmt.assemblyParamPassingConvention match { case ByVariable(name) => - val zp = typ.name == "pointer" // TODO + val zp = typ.isPointy // TODO val v = UninitializedMemoryVariable(prefix + name, typ, if (zp) VariableAllocationMethod.Zeropage else VariableAllocationMethod.Auto, None, defaultVariableAlignment(options, 2), isVolatile = false) addThing(v, stmt.position) - registerAddressConstant(v, stmt.position, options) + registerAddressConstant(v, stmt.position, options, Some(typ)) val addr = v.toAddress for((suffix, offset, t) <- getSubvariables(typ)) { val subv = RelativeVariable(v.name + suffix, addr + offset, t, zeropage = zp, None, isVolatile = v.isVolatile) addThing(subv, stmt.position) - registerAddressConstant(subv, stmt.position, options) + registerAddressConstant(subv, stmt.position, options, Some(t)) } case ByMosRegister(_) => () case ByZRegister(_) => () @@ -1093,7 +1157,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa declaredBank = stmt.bank, indexType, e) } addThing(array, stmt.position) - registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, stmt.bank, alignment, isVolatile = false), stmt.position, options) + registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, stmt.bank, alignment, isVolatile = false), stmt.position, options, Some(e)) val a = address match { case None => array.toAddress case Some(aa) => aa @@ -1165,7 +1229,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val array = InitializedArray(stmt.name + ".array", address, contents, declaredBank = stmt.bank, indexType, e, alignment) addThing(array, stmt.position) registerAddressConstant(UninitializedMemoryVariable(stmt.name, p, VariableAllocationMethod.None, - declaredBank = stmt.bank, alignment, isVolatile = false), stmt.position, options) + declaredBank = stmt.bank, alignment, isVolatile = false), stmt.position, options, Some(e)) val a = address match { case None => array.toAddress case Some(aa) => aa @@ -1249,7 +1313,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } InitializedMemoryVariable(name, None, typ, ive, declaredBank = stmt.bank, alignment, isVolatile = stmt.volatile) } - registerAddressConstant(v, stmt.position, options) + registerAddressConstant(v, stmt.position, options, Some(typ)) (v, v.toAddress) })(a => { val addr = eval(a).getOrElse(errorConstant(s"Address of `$name` has a non-constant value", position)) @@ -1259,7 +1323,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } val v = RelativeVariable(prefix + name, addr, typ, zeropage = zp, declaredBank = stmt.bank, isVolatile = stmt.volatile) - registerAddressConstant(v, stmt.position, options) + registerAddressConstant(v, stmt.position, options, Some(typ)) (v, addr) }) addThing(v, stmt.position) @@ -1269,7 +1333,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa for((suffix, offset, t) <- getSubvariables(typ)) { val subv = RelativeVariable(prefix + name + suffix, addr + offset, t, zeropage = v.zeropage, declaredBank = stmt.bank, isVolatile = v.isVolatile) addThing(subv, stmt.position) - registerAddressConstant(subv, stmt.position, options) + registerAddressConstant(subv, stmt.position, options, Some(t)) } } } @@ -1323,6 +1387,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa List.tabulate(sz){ i => (".b" + i, i, b) } case _ => Nil } + case p: PointerType => + List( + (".raw", 0, p), + (".raw.lo", 0, b), + (".raw.hi", 1, b), + (".lo", 0, b), + (".hi", 1, b)) case s: StructType => val builder = new ListBuffer[(String, Int, VariableType)] var offset = 0 @@ -1336,6 +1407,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa offset += typ.size } builder.toList + case s: UnionType => + val builder = new ListBuffer[(String, Int, VariableType)] + for((typeName, fieldName) <- s.fields) { + val typ = get[VariableType](typeName) + val suffix = "." + fieldName + builder += ((suffix, 0, typ)) + builder ++= getSubvariables(typ).map { + case (innerSuffix, innerOffset, innerType) => (suffix + innerSuffix, innerOffset, innerType) + } + } + builder.toList case _ => Nil } } @@ -1393,6 +1475,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def fixStructSizes(): Unit = { val allStructTypes = things.values.flatMap { case StructType(name, _) => Some(name) + case UnionType(name, _) => Some(name) case _ => None } var iterations = allStructTypes.size @@ -1423,6 +1506,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } program.declarations.foreach { case s: StructDefinitionStatement => registerStruct(s) + case s: UnionDefinitionStatement => registerUnion(s) case _ => } fixStructSizes() @@ -1518,6 +1602,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case IndexedExpression(name, index) => checkName[IndexableThing]("Array or pointer", name, node.position) nameCheck(index) + case DerefDebuggingExpression(inner, _) => + nameCheck(inner) + case DerefExpression(inner, _, _) => + nameCheck(inner) + case IndirectFieldExpression(inner, _) => + nameCheck(inner) case SeparateBytesExpression(h, l) => nameCheck(h) nameCheck(l) @@ -1560,5 +1650,5 @@ object Environment { "for", "if", "do", "while", "else", "return", "default", "to", "until", "paralleluntil", "parallelto", "downto", "inline", "noinline" ) ++ predefinedFunctions - val invalidFieldNames: Set[String] = Set("addr") + val invalidFieldNames: Set[String] = Set("addr", "rawaddr", "pointer") } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index eaec0d0f..662c68f3 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -34,6 +34,8 @@ sealed trait Type extends CallableThing { def isArithmetic = false def isPointy = false + + def pointerTargetName: String = "byte" } sealed trait VariableType extends Type @@ -70,13 +72,31 @@ case class DerivedPlainType(name: String, parent: PlainType, isSigned: Boolean, override def isSubtypeOf(other: Type): Boolean = parent == other || parent.isSubtypeOf(other) } +case class PointerType(name: String, targetName: String, var target: Option[Type]) extends VariableType { + def size = 2 + + override def isSigned: Boolean = false + + override def isPointy: Boolean = true + + override def pointerTargetName: String = targetName +} + case class EnumType(name: String, count: Option[Int]) extends VariableType { override def size: Int = 1 override def isSigned: Boolean = false } -case class StructType(name: String, fields: List[(String, String)]) extends VariableType { +sealed trait CompoundVariableType extends VariableType + +case class StructType(name: String, fields: List[(String, String)]) extends CompoundVariableType { + override def size: Int = mutableSize + var mutableSize: Int = -1 + override def isSigned: Boolean = false +} + +case class UnionType(name: String, fields: List[(String, String)]) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 override def isSigned: Boolean = false diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index dbda16de..b8f28614 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -182,7 +182,7 @@ case class VariableExpression(name: String) extends LhsExpression { override def replaceVariable(variable: String, actualParam: Expression): Expression = if (name == variable) actualParam else this override def containsVariable(variable: String): Boolean = name == variable - override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.takeWhile(_ != '.')) else Seq.empty + override def getPointies: Seq[String] = if (name.endsWith(".addr.lo")) Seq(name.stripSuffix(".addr.lo")) else Seq.empty override def isPure: Boolean = true override def getAllIdentifiers: Set[String] = Set(name) } @@ -202,8 +202,54 @@ case class IndexedExpression(name: String, index: Expression) extends LhsExpress override def getAllIdentifiers: Set[String] = index.getAllIdentifiers + name } +case class IndirectFieldExpression(root: Expression, fields: List[String]) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = IndirectFieldExpression(root.replaceVariable(variable, actualParam), fields) + + override def containsVariable(variable: String): Boolean = root.containsVariable(variable) + + override def getPointies: Seq[String] = root match { + case VariableExpression(v) => List(v) + case _ => root.getPointies + } + + override def isPure: Boolean = root.isPure + + override def getAllIdentifiers: Set[String] = root.getAllIdentifiers +} + +case class DerefDebuggingExpression(inner: Expression, preferredSize: Int) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefDebuggingExpression(inner.replaceVariable(variable, actualParam), preferredSize) + + override def containsVariable(variable: String): Boolean = inner.containsVariable(variable) + + override def getPointies: Seq[String] = inner match { + case VariableExpression(v) => List(v) + case _ => inner.getPointies + } + + override def isPure: Boolean = inner.isPure + + override def getAllIdentifiers: Set[String] = inner.getAllIdentifiers +} + +case class DerefExpression(inner: Expression, offset: Int, targetType: Type) extends LhsExpression { + override def replaceVariable(variable: String, actualParam: Expression): Expression = DerefExpression(inner.replaceVariable(variable, actualParam), offset, targetType) + + override def containsVariable(variable: String): Boolean = inner.containsVariable(variable) + + override def getPointies: Seq[String] = inner match { + case VariableExpression(v) => List(v) + case _ => inner.getPointies + } + + override def isPure: Boolean = inner.isPure + + override def getAllIdentifiers: Set[String] = inner.getAllIdentifiers +} + sealed trait Statement extends Node { def getAllExpressions: List[Expression] + def getAllPointies: Seq[String] = getAllExpressions.flatMap(_.getPointies) } @@ -304,6 +350,10 @@ case class StructDefinitionStatement(name: String, fields: List[(String, String) override def getAllExpressions: List[Expression] = Nil } +case class UnionDefinitionStatement(name: String, fields: List[(String, String)]) extends DeclarationStatement { + override def getAllExpressions: List[Expression] = Nil +} + case class ArrayDeclarationStatement(name: String, bank: Option[String], length: Option[Expression], diff --git a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala index 12a43376..047e77a4 100644 --- a/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala +++ b/src/main/scala/millfork/node/opt/UnusedLocalVariables.scala @@ -53,6 +53,9 @@ object UnusedLocalVariables extends NodeOptimization { case SumExpression(xs, _) => getAllReadVariables(xs.map(_._2)) case FunctionCallExpression(_, xs) => getAllReadVariables(xs) case IndexedExpression(arr, index) => arr :: getAllReadVariables(List(index)) + case DerefExpression(inner, _, _) => getAllReadVariables(List(inner)) + case DerefDebuggingExpression(inner, _) => getAllReadVariables(List(inner)) + case IndirectFieldExpression(inner, _) => getAllReadVariables(List(inner)) case SeparateBytesExpression(h, l) => getAllReadVariables(List(h, l)) case _ => Nil } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index a71a7398..095000af 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -27,6 +27,12 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri def allowIntelHexAtomsInAssembly: Boolean + val enableDebuggingOptions: Boolean = { + val x = options.flag(CompilationFlag.EnableInternalTestSyntax) + println(s"enableDebuggingOptions = $x") + x + } + def toAst: Parsed[Program] = program.parse(input + "\n\n\n") private val lineStarts: Array[Int] = (0 +: input.zipWithIndex.filter(_._1 == '\n').map(_._2)).toArray @@ -247,12 +253,18 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri def tightMfExpression(allowIntelHex: Boolean): P[Expression] = { val a = if (allowIntelHex) atomWithIntel else atom - P(mfParenExpr(allowIntelHex) | functionCall(allowIntelHex) | mfIndexedExpression | a) // TODO + for { + expression <- mfParenExpr(allowIntelHex) | derefExpression | functionCall(allowIntelHex) | mfIndexedExpression | a + fieldPath <- ("->" ~/ AWS ~/ identifier).rep + } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) } def tightMfExpressionButNotCall(allowIntelHex: Boolean): P[Expression] = { val a = if (allowIntelHex) atomWithIntel else atom - P(mfParenExpr(allowIntelHex) | mfIndexedExpression | a) // TODO + for { + expression <- mfParenExpr(allowIntelHex) | derefExpression | mfIndexedExpression | a + fieldPath <- ("->" ~/ AWS ~/ identifier).rep + } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) } def mfExpression(level: Int, allowIntelHex: Boolean): P[Expression] = { @@ -299,7 +311,10 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri inner.map(x => p(x, 0)) } - def mfLhsExpressionSimple: P[LhsExpression] = mfIndexedExpression | (position() ~ identifier).map { case (p, n) => VariableExpression(n).pos(p) } + def mfLhsExpressionSimple: P[LhsExpression] = for { + expression <- mfIndexedExpression | derefExpression | (position() ~ identifier).map{case (p,n) => VariableExpression(n).pos(p)} ~ HWS + fieldPath <- ("->" ~/ AWS ~/ identifier).rep + } yield if (fieldPath.isEmpty) expression else IndirectFieldExpression(expression, fieldPath.toList) def mfLhsExpression: P[LhsExpression] = for { (p, left) <- position() ~ mfLhsExpressionSimple @@ -321,6 +336,13 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri params <- HWS ~ "(" ~/ AWS ~/ mfExpression(nonStatementLevel, allowIntelHex).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "" } yield FunctionCallExpression(name, params.toList).pos(p) + val derefExpression: P[DerefDebuggingExpression] = for { + p <- position() + if enableDebuggingOptions + yens <- CharsWhileIn(Seq('¥')).! ~/ AWS + inner <- mfParenExpr(false) + } yield DerefDebuggingExpression(inner, yens.length).pos(p) + val expressionStatement: P[Seq[ExecutableStatement]] = mfExpression(0, false).map(x => Seq(ExpressionStatement(x))) val assignmentStatement: P[Seq[ExecutableStatement]] = @@ -474,25 +496,33 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri variants <- enumVariants ~/ Pass } yield Seq(EnumDefinitionStatement(name, variants).pos(p)) - val structField: P[(String, String)] = for { + val compoundTypeField: P[(String, String)] = for { typ <- identifier ~/ HWS name <- identifier ~ HWS } yield typ -> name - val structFields: P[List[(String, String)]] = - ("{" ~/ AWS ~ structField.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) + val compoundTypeFields: P[List[(String, String)]] = + ("{" ~/ AWS ~ compoundTypeField.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) val structDefinition: P[Seq[StructDefinitionStatement]] = for { p <- position() _ <- "struct" ~ !letterOrDigit ~/ SWS ~ position("struct name") - name <- identifier ~/ HWS - _ <- position("struct defintion block") - fields <- structFields ~/ Pass + name <- identifier ~/ HWS + _ <- position("struct defintion block") + fields <- compoundTypeFields ~/ Pass } yield Seq(StructDefinitionStatement(name, fields).pos(p)) + val unionDefinition: P[Seq[UnionDefinitionStatement]] = for { + p <- position() + _ <- "union" ~ !letterOrDigit ~/ SWS ~ position("union name") + name <- identifier ~/ HWS + _ <- position("union defintion block") + fields <- compoundTypeFields ~/ Pass + } yield Seq(UnionDefinitionStatement(name, fields).pos(p)) + val program: Parser[Program] = for { _ <- Start ~/ AWS ~/ Pass - definitions <- (importStatement | arrayDefinition | aliasDefinition | enumDefinition | structDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL) + definitions <- (importStatement | arrayDefinition | aliasDefinition | enumDefinition | structDefinition | unionDefinition | functionDefinition | globalVariableDefinition).rep(sep = EOL) _ <- AWS ~ End } yield Program(definitions.flatten.toList) diff --git a/src/test/scala/millfork/test/DerefSuite.scala b/src/test/scala/millfork/test/DerefSuite.scala new file mode 100644 index 00000000..bc93c956 --- /dev/null +++ b/src/test/scala/millfork/test/DerefSuite.scala @@ -0,0 +1,86 @@ +package millfork.test + +import millfork.Cpu +import millfork.test.emu.EmuUnoptimizedCrossPlatformRun +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class DerefSuite extends FunSuite with Matchers { + test("Basic deref test") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | + | byte output @$c000 + | word output2 @$c010 + | byte crash @$bfff + | void main() { + | pointer p + | p = id(output.addr) + | ¥(p) = 13 + | if (¥(p) != 13) { crash = 1 } + | p = id(output2.addr) + | ¥¥(p) = 600 + | if (¥¥(p) != 600) { crash = 2 } + | } + | + | noinline pointer id(pointer x) { return x } + """.stripMargin){m => + m.readByte(0xc000) should equal (13) + m.readWord(0xc010) should equal (600) + } + } + + test("Byte arithmetic deref test") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | + | byte output1 @$c000 + | byte output2 @$c001 + | byte crash @$bfff + | void main() { + | pointer p + | pointer q + | p = id(output1.addr) + | q = id(output2.addr) + | ¥(p) = 1 + | ¥(q) = 3 + | ¥(p) += 1 + | ¥(p) |= 1 + | ¥(q) = ¥(p) + ¥(q) + | } + | + | noinline pointer id(pointer x) { return x } + """.stripMargin){m => + m.readByte(0xc000) should equal (3) + m.readByte(0xc001) should equal (6) + } + } + + test("Word arithmetic deref test") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | + | word output1 @$c000 + | word output2 @$c002 + | byte crash @$bfff + | void main () { + | pointer p + | pointer q + | p = id(output1.addr) + | q = id(output2.addr) + | ¥¥(p) = 100 + | ¥¥(q) = 300 + | ¥¥(p) = ¥¥(p) + 100 + | ¥¥(p) = ¥¥(p) ^ (300 ^ 200) + | ¥¥(q) = ¥¥(p) + ¥¥(q) + | } + | + | noinline pointer id(pointer x) { return x } + """.stripMargin) { m => + m.readWord(0xc000) should equal(300) + m.readWord(0xc002) should equal(600) + } + } +} diff --git a/src/test/scala/millfork/test/PointerSuite.scala b/src/test/scala/millfork/test/PointerSuite.scala index fe4c26be..12b53eec 100644 --- a/src/test/scala/millfork/test/PointerSuite.scala +++ b/src/test/scala/millfork/test/PointerSuite.scala @@ -1,7 +1,7 @@ package millfork.test import millfork.Cpu -import millfork.test.emu.EmuCrossPlatformBenchmarkRun +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun} import org.scalatest.{FunSuite, Matchers} /** @@ -39,4 +39,117 @@ class PointerSuite extends FunSuite with Matchers { } } + test("Typed byte-targeting pointers") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Sixteen, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | enum e {} + | array(e) output [3] @$c000 + | void main() { + | pointer.e p + | e x + | p = output.pointer + | x = p[0] + | p[0] = e(14) + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(14) + } + } + + test("Typed word-targeting pointers") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Sixteen, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | word output @$c000 + | void main() { + | pointer.word p + | word x + | p = output.pointer + | x = p[0] + | p[0] = 1589 + | } + """.stripMargin) { m => + m.readWord(0xc000) should equal(1589) + } + } + + test("Struct pointers") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)( + """ + | struct point { word x, word y } + | struct pointlist { point head, pointer.pointlist tail } + | array output [5] @$c000 + | array heap[$300] @$c400 + | + | pointer.pointlist this + | pointer heapEnd + | + | pointer.pointlist alloc() { + | pointer.pointlist result + | result = pointer.pointlist(heapEnd) + | heapEnd += sizeof(pointlist) + | return result + | } + | + | void prepend(point p) { + | pointer.pointlist new + | new = alloc() + | // can't copy things larger than 2 bytes right now: + | new->head.x = p.x + | new->head.y = p.y + | new->tail = this + | this = new + | } + | + | void main() { + | heapEnd = heap.addr + | this = pointer.pointlist(0) + | point tmp + | tmp.x = 3 + | tmp.y = 3 + | prepend(tmp) + | tmp.x = 4 + | prepend(tmp) + | tmp.x = 5 + | prepend(tmp) + | + | pointer.pointlist cursor + | byte index + | index = 0 + | cursor = this + | while cursor != 0 { + | output[index] = cursor->head.x.lo + | index += 1 + | cursor = cursor->tail + | } + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(5) + m.readByte(0xc001) should equal(4) + m.readByte(0xc002) should equal(3) + } + } + + test("Pointer optimization") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80)( + """ + | struct s { word a, byte b } + | + | s global_s + | byte output @$c000 + | word output_sink @$c005 + | + | noinline pointer.s init() { + | global_s.b = 44 + | return global_s.pointer + | } + | void main() { + | pointer.s p + | output_sink = p.addr + | p = init() + | output = p->b + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(44) + } + } } diff --git a/src/test/scala/millfork/test/StructSuite.scala b/src/test/scala/millfork/test/StructSuite.scala index dd1c43ec..ceddcc1e 100644 --- a/src/test/scala/millfork/test/StructSuite.scala +++ b/src/test/scala/millfork/test/StructSuite.scala @@ -11,7 +11,7 @@ class StructSuite extends FunSuite with Matchers { test("Basic struct support") { // TODO: 8080 has broken stack operations, fix and uncomment! - EmuUnoptimizedCrossPlatformRun(Cpu.StrictMos, Cpu.Z80/*, Cpu.Intel8080*/)(""" + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80/*, Cpu.Intel8080*/)(""" | struct point { | byte x | byte y @@ -41,7 +41,7 @@ class StructSuite extends FunSuite with Matchers { } test("Nested structs") { - EmuUnoptimizedCrossPlatformRun(Cpu.StrictMos, Cpu.Intel8080)(""" + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Intel8080)(""" | struct inner { word x, word y } | struct s { | word w @@ -68,4 +68,20 @@ class StructSuite extends FunSuite with Matchers { m.readWord(0xc007) should equal(777) } } + + test("Basic union support") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Intel8080)(""" + | struct point { byte x, byte y } + | union point_or_word { point p, word w } + | word output @$c000 + | void main () { + | point_or_word u + | u.p.x = 1 + | u.p.y = 2 + | output = u.w + | } + """.stripMargin) { m => + m.readWord(0xc000) should equal(0x201) + } + } } diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index 2ef8cda5..8663e9f7 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -131,6 +131,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], println(source) val platform = EmuPlatform.get(cpu) val options = CompilationOptions(platform, Map( + CompilationFlag.EnableInternalTestSyntax -> true, CompilationFlag.DecimalMode -> millfork.Cpu.defaultFlags(cpu).contains(CompilationFlag.DecimalMode), CompilationFlag.LenientTextEncoding -> true, CompilationFlag.EmitIllegals -> this.emitIllegals, diff --git a/src/test/scala/millfork/test/emu/EmuZ80Run.scala b/src/test/scala/millfork/test/emu/EmuZ80Run.scala index c9e49178..04785735 100644 --- a/src/test/scala/millfork/test/emu/EmuZ80Run.scala +++ b/src/test/scala/millfork/test/emu/EmuZ80Run.scala @@ -77,6 +77,7 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio println(source) val platform = EmuPlatform.get(cpu) val extraFlags = Map( + CompilationFlag.EnableInternalTestSyntax -> true, CompilationFlag.InlineFunctions -> this.inline, CompilationFlag.OptimizeStdlib -> this.inline, CompilationFlag.OptimizeForSize -> this.optimizeForSize,