From d38405f46750f71d3d395eef8037eb4ee308b88b Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 20 Sep 2019 18:33:41 +0200 Subject: [PATCH] Fix signed constants and word-sbyte subtraction --- CHANGELOG.md | 2 +- examples/atari_lynx/atari_lynx_demo.mfk | 8 +-- .../millfork/compiler/mos/BuiltIns.scala | 41 ++++++++++---- .../compiler/mos/MosExpressionCompiler.scala | 18 +++---- .../compiler/mos/PseudoregisterBuiltIns.scala | 6 +-- .../compiler/z80/Z80ExpressionCompiler.scala | 53 ++++++++++++++++--- src/main/scala/millfork/env/Constant.scala | 29 ++++++++++ src/main/scala/millfork/env/Environment.scala | 45 ++++++++-------- src/main/scala/millfork/env/Thing.scala | 4 +- src/main/scala/millfork/node/Node.scala | 6 ++- src/main/scala/millfork/parser/MfParser.scala | 6 +-- .../scala/millfork/test/ConstantSuite.scala | 12 +++++ .../scala/millfork/test/WordMathSuite.scala | 33 ++++++++++++ 13 files changed, 195 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12d63948..48baef5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ * Added `vectrex`, `msx_br` and `koi7n2` text encodings. -* 6502: Fixed arithmetic promotion bugs for function return values. +* Fixed arithmetic promotion bugs for signed values. * Fixed parsing of `zp_bytes` in platform definitions. diff --git a/examples/atari_lynx/atari_lynx_demo.mfk b/examples/atari_lynx/atari_lynx_demo.mfk index 2cab0199..c9c0884a 100644 --- a/examples/atari_lynx/atari_lynx_demo.mfk +++ b/examples/atari_lynx/atari_lynx_demo.mfk @@ -116,13 +116,7 @@ pointer source } else { // or just move it - // FIXME: word-subbyte doesn't work yet - if input_dy==$ff { - demosp.ypos+=1 - } - if input_dy==$01 { - demosp.ypos-=1 - } + demosp.ypos -= input_dy demosp.xpos += input_dx } } diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 2f7569a3..87e8dc82 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -1355,7 +1355,6 @@ object BuiltIns { } } } - val label = ctx.nextLabel("de") return doDec(targetBytes) case Some(NumericConstant(-1, _)) if canUseIncDec && !subtract => if (ctx.options.flags(CompilationFlag.Emit65CE02Opcodes)) { @@ -1378,7 +1377,9 @@ object BuiltIns { return doDec(targetBytes) case Some(constant) => addendSize = targetSize - Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, constant.subbyte(i)))) + Nil -> List.tabulate(targetSize)(i => List(AssemblyLine.immediate(LDA, + if (i >= addendType.size && constant.isProvablyNegative(addendType)) NumericConstant(-1, 1) else constant.subbyte(i) + ))) case None => addendSize match { case 1 => @@ -1436,17 +1437,17 @@ object BuiltIns { case _ => addend match { case vv: VariableExpression => val source = env.get[Variable](vv.name) - Nil -> List.tabulate(addendSize)(i => AssemblyLine.variable(ctx, LDA, source, i)) + Nil -> List.tabulate(targetSize)(i => AssemblyLine.variable(ctx, LDA, source, i)) case f: FunctionCallExpression => val jsr = MosExpressionCompiler.compile(ctx, addend, None, BranchSpec.None) val result = ctx.env.get[VariableInMemory](f.functionName + ".return") - jsr -> List.tabulate(addendSize)(i => AssemblyLine.variable(ctx, LDA, result, i)) + jsr -> List.tabulate(targetSize)(i => AssemblyLine.variable(ctx, LDA, result, i)) } } } val addendByteRead = addendByteRead0 ++ List.fill((targetSize - addendByteRead0.size) max 0)(List(AssemblyLine.immediate(LDA, 0))) - if (ctx.options.flags(CompilationFlag.EmitNative65816Opcodes)) { + if (ctx.options.flags(CompilationFlag.EmitNative65816Opcodes) && !addendType.isSigned) { (removeTsx(targetBytes), calculateRhs, removeTsx(addendByteRead)) match { case ( List(List(AssemblyLine0(STA, ta1, tl)), List(AssemblyLine0(STA, ta2, th))), @@ -1520,12 +1521,32 @@ object BuiltIns { } buffer ++= targetBytes(i) } else if (subtract) { - if (addendSize < targetSize && addendType.isSigned) { - // TODO: sign extension - ??? + if (i >= addendSize) { + if (addendType.isSigned && !decimal) { + buffer += AssemblyLine.implied(TXA) + buffer ++= staTo(ADC, targetBytes(i)) + } else { + buffer ++= staTo(LDA, targetBytes(i)) + buffer ++= wrapInSedCldIfNeeded(decimal, ldTo(SBC, addendByteRead(i))) + } + } else { + if (addendType.isSigned && i == addendSize - 1 && extendAtLeastOneByte && !decimal) { + val label = ctx.nextLabel("sx") + buffer ++= addendByteRead(i) + buffer += AssemblyLine.immediate(EOR, 0xff) + buffer += AssemblyLine.implied(PHA) + buffer += AssemblyLine.immediate(ORA, 0x7f) + buffer += AssemblyLine.relative(BMI, label) + buffer += AssemblyLine.immediate(LDA, 0) + buffer += AssemblyLine.label(label) + buffer += AssemblyLine.implied(TAX) + buffer += AssemblyLine.implied(PLA) + buffer ++= staTo(ADC, targetBytes(i)) + } else { + buffer ++= staTo(LDA, targetBytes(i)) + buffer ++= wrapInSedCldIfNeeded(decimal, ldTo(SBC, addendByteRead(i))) + } } - buffer ++= staTo(LDA, targetBytes(i)) - buffer ++= wrapInSedCldIfNeeded(decimal, ldTo(SBC, addendByteRead(i))) buffer ++= targetBytes(i) } else { if (i >= addendSize) { diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index 814b5179..9008df91 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -16,7 +16,7 @@ import millfork.output.NoAlignment */ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { - def compileConstant(ctx: CompilationContext, expr: Constant, target: Variable): List[AssemblyLine] = { + def compileConstant(ctx: CompilationContext, expr: Constant, exprType: Type, target: Variable): List[AssemblyLine] = { target match { case RegisterVariable(MosRegister.A, _) => List(AssemblyLine(LDA, Immediate, expr)) case RegisterVariable(MosRegister.AW, _) => @@ -617,7 +617,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case Some(value) => return exprTypeAndVariable.fold(noop) { case (exprType, target) => assertCompatible(exprType, target.typ) - compileConstant(ctx, value, target) + compileConstant(ctx, value, exprType, target) } case _ => } @@ -627,17 +627,17 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case LiteralExpression(value, size) => exprTypeAndVariable.fold(noop) { case (exprType, target) => assertCompatible(exprType, target.typ) - compileConstant(ctx, NumericConstant(value, size), target) + compileConstant(ctx, NumericConstant(value, size), exprType, target) } case GeneratedConstantExpression(value, _) => exprTypeAndVariable.fold(noop) { case (exprType, target) => assertCompatible(exprType, target.typ) - compileConstant(ctx, value, target) + compileConstant(ctx, value, exprType, target) } case VariableExpression(name) => exprTypeAndVariable.fold(noop) { case (exprType, target) => assertCompatible(exprType, target.typ) - env.eval(expr).map(c => compileConstant(ctx, c, target)).getOrElse { + env.eval(expr).map(c => compileConstant(ctx, c, exprType, target)).getOrElse { env.get[TypedThing](name) match { case source: StackOffsetThing => compileStackOffset(ctx, target, source.offset, source.subbyte) case source: VariableInMemory => @@ -902,7 +902,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { } } case source@ConstantThing(_, value, _) => - compileConstant(ctx, value, target) + compileConstant(ctx, value, exprType, target) } } } @@ -1080,7 +1080,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { if (neg) MathOperator.Minus else MathOperator.Plus }, c, v).quickSimplify } - exprTypeAndVariable.map(x => compileConstant(ctx, value.quickSimplify, x._2)).getOrElse(Nil) + exprTypeAndVariable.map(x => compileConstant(ctx, value.quickSimplify, exprType, x._2)).getOrElse(Nil) } else { getSumSize(ctx, params) match { case 1 => @@ -1244,7 +1244,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case Some(c) => exprTypeAndVariable match { case Some((t, v)) => - compileConstant(ctx, c, v) + compileConstant(ctx, c, w, v) case _ => Nil } @@ -1259,7 +1259,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { case Some(c) => exprTypeAndVariable match { case Some((t, v)) => - compileConstant(ctx, c, v) + compileConstant(ctx, c, w, v) case _ => Nil } diff --git a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala index 86645316..94e4478a 100644 --- a/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala @@ -38,7 +38,7 @@ object PseudoregisterBuiltIns { val result = ListBuffer[AssemblyLine]() var hard = Option.empty[List[AssemblyLine]] val niceReads = mutable.ListBuffer[(List[AssemblyLine], List[AssemblyLine])]() - var constant = Constant.Zero + var constant: Constant = NumericConstant(0, 2) var counter = 0 for ((subtract, read) <- reads) { read match { @@ -99,7 +99,7 @@ object PseudoregisterBuiltIns { val (variablePart, constPart) = ctx.env.evalVariableAndConstantSubParts(SumExpression(params, decimal = false)) variablePart match { case None => - return MosExpressionCompiler.compileConstant(ctx, constPart, RegisterVariable(MosRegister.AX, w)) + return MosExpressionCompiler.compileConstant(ctx, constPart, w, RegisterVariable(MosRegister.AX, w)) case Some(v) => val typ = MosExpressionCompiler.getExpressionType(ctx, v) if (typ.size == 1 && !typ.isSigned) { @@ -146,7 +146,7 @@ object PseudoregisterBuiltIns { val (variablePart, constPart) = ctx.env.evalVariableAndConstantSubParts(SumExpression(params, decimal = false)) variablePart match { case None => - return MosExpressionCompiler.compileConstant(ctx, constPart, RegisterVariable(MosRegister.AW, w)) + return MosExpressionCompiler.compileConstant(ctx, constPart, w, RegisterVariable(MosRegister.AW, w)) case Some(v) => } } diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index d92741ff..c354db02 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -354,7 +354,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { import ZRegister._ v.typ.size match { case 0 => ??? - case 1 => loadByte(v.toAddress, target, v.isVolatile) + case 1 => loadByte(ctx, v.toAddress, target, v.isVolatile, v.typ.isSigned) case 2 => target match { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.HL => @@ -548,7 +548,7 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } case i: IndexedExpression => calculateAddressToHL(ctx, i, forWriting = false) match { - case List(ZLine0(LD_16, TwoRegisters(ZRegister.HL, ZRegister.IMM_16), addr)) => loadByte(addr, target, volatile = false) + case List(ZLine0(LD_16, TwoRegisters(ZRegister.HL, ZRegister.IMM_16), addr)) => loadByte(ctx, addr, target, volatile = false, signExtend = false) case code => code ++ loadByteViaHL(target) } case SumExpression(params, decimal) => @@ -1447,16 +1447,53 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { } } - def loadByte(sourceAddr: Constant, target: ZExpressionTarget.Value, volatile: Boolean): List[ZLine] = { + def loadByte(ctx: CompilationContext, sourceAddr: Constant, target: ZExpressionTarget.Value, volatile: Boolean, signExtend: Boolean): List[ZLine] = { + import ZRegister._ + import ZLine.{ld8, ldImm8, ldAbs8, ldImm16} val elidability = if (volatile) Elidability.Volatile else Elidability.Elidable target match { case ZExpressionTarget.NOTHING => Nil case ZExpressionTarget.A => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability)) - case ZExpressionTarget.HL => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0)) - case ZExpressionTarget.BC => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability), ZLine.ld8(ZRegister.C, ZRegister.A), ZLine.ldImm8(ZRegister.B, 0)) - case ZExpressionTarget.DE => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability), ZLine.ld8(ZRegister.E, ZRegister.A), ZLine.ldImm8(ZRegister.D, 0)) - case ZExpressionTarget.EHL => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0), ZLine.ldImm8(ZRegister.E, 0)) - case ZExpressionTarget.DEHL => List(ZLine.ldAbs8(ZRegister.A, sourceAddr, elidability), ZLine.ld8(ZRegister.L, ZRegister.A), ZLine.ldImm8(ZRegister.H, 0), ZLine.ldImm16(ZRegister.DE, 0)) + case ZExpressionTarget.HL => + if (signExtend) { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A)) ++ + signExtendHighestByte(ctx, A, signExtend) ++ + List(ld8(H, A)) + } else { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A), ldImm8(H, 0)) + } + case ZExpressionTarget.BC => + if (signExtend) { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A)) ++ + signExtendHighestByte(ctx, A, signExtend) ++ + List(ld8(H, A)) + } else { + List(ldAbs8(A, sourceAddr, elidability), ld8(C, A), ldImm8(B, 0)) + } + case ZExpressionTarget.DE => + if (signExtend) { + List(ldAbs8(A, sourceAddr, elidability), ld8(E, A)) ++ + signExtendHighestByte(ctx, A, signExtend) ++ + List(ld8(D, A)) + } else { + List(ldAbs8(A, sourceAddr, elidability), ld8(E, A), ldImm8(D, 0)) + } + case ZExpressionTarget.EHL => + if (signExtend) { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A)) ++ + signExtendHighestByte(ctx, A, signExtend) ++ + List(ld8(H, A), ld8(E, A)) + } else { + List(ldAbs8(ZRegister.A, sourceAddr, elidability), ld8(L, A), ldImm8(H, 0), ldImm8(E, 0)) + } + case ZExpressionTarget.DEHL => + if (signExtend) { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A)) ++ + signExtendHighestByte(ctx, A, signExtend) ++ + List(ld8(H, A), ld8(E, A), ld8(D, A)) + } else { + List(ldAbs8(A, sourceAddr, elidability), ld8(L, A), ldImm8(H, 0), ldImm16(DE, 0)) + } } } diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 16681a84..e17741ad 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -25,6 +25,7 @@ sealed trait Constant { def isProvably(value: Int): Boolean = false def isProvablyInRange(startInclusive: Int, endInclusive: Int): Boolean = false def isProvablyNonnegative: Boolean = false + def isProvablyNegative(asType: Type): Boolean = false final def isProvablyGreaterOrEqualThan(other: Int): Boolean = isProvablyGreaterOrEqualThan(Constant(other)) def isProvablyGreaterOrEqualThan(other: Constant): Boolean = other match { case NumericConstant(0, _) => true @@ -114,6 +115,15 @@ sealed trait Constant { def fitsInto(typ: Type): Boolean = true // TODO + def fitInto(typ: Type): Constant = { + // TODO: + typ.size match { + case 1 => loByte + case 2 => subword(0) + case _ => this + } + } + final def succ: Constant = (this + 1).quickSimplify } @@ -123,6 +133,7 @@ case class AssertByte(c: Constant) extends Constant { override def isProvablyZero: Boolean = c.isProvablyZero override def isProvably(i: Int): Boolean = c.isProvably(i) override def isProvablyNonnegative: Boolean = c.isProvablyNonnegative + override def isProvablyNegative(asType: Type): Boolean = c.isProvablyNegative(asType) override def isProvablyInRange(startInclusive: Int, endInclusive: Int): Boolean = c.isProvablyInRange(startInclusive, endInclusive) override def fitsProvablyIntoByte: Boolean = true @@ -204,6 +215,11 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant { override def isProvablyZero: Boolean = value == 0 override def isProvably(i: Int): Boolean = value == i override def isProvablyNonnegative: Boolean = value >= 0 + override def isProvablyNegative(asType: Type): Boolean = { + if (!asType.isSigned) return false + if (asType.size >= 8) return value < 0 + value.&(0x1L.<<(8 * asType.size - 1)) != 0 + } override def fitsProvablyIntoByte: Boolean = requiredSize == 1 override def isProvablyDivisibleBy256: Boolean = (value & 0xff) == 0 @@ -251,6 +267,19 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant { } } } + + override def fitInto(typ: Type): Constant = { + if (typ.size >= 8) { + return NumericConstant(value, typ.size) + } + val actualBits = 1L.<<(8 * typ.size).-(1).&(value) + if (isProvablyNegative(typ)) { + val sx = (-1L).<<(8 * typ.size) + NumericConstant(sx | actualBits, typ.size) + } else { + NumericConstant(actualBits, typ.size) + } + } } case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant { diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index da4ff232..91313f6a 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -573,7 +573,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa None -> Constant.Zero case Some(t: Type) => if (t.isSigned) Some(e) -> Constant.Zero - else variable -> constant + else variable -> constant.fitInto(t) case _ => // dunno what to do Some(e) -> Constant.Zero @@ -768,18 +768,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa maybeGet[Type](name) match { case Some(t: StructType) => if (params.size == t.fields.size) { - sequence(params.map(eval)).map(fields => StructureConstant(t, fields)) + sequence(params.map(eval)).map(fields => StructureConstant(t, fields.zip(t.fields).map{ + case (fieldConst, fieldDesc) => + fieldConst.fitInto(get[Type](fieldDesc.typeName)) + })) } else None case Some(_: UnionType) => None case Some(t) => if (params.size == 1) { eval(params.head).map{ c => - (t.size, t.isSigned) match { - case (1, false) => c.loByte - case (2, false) => c.subword(0) - case _ => c // TODO - } + c.fitInto(t) } } else None case _ => None @@ -892,15 +891,15 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa value = eval(v).getOrElse(errorConstant(s"Enum constant `${stmt.name}.$name` is not a constant", stmt.position)) case _ => } - addThing(ConstantThing(name, value, t), stmt.position) + addThing(ConstantThing(name, value.fitInto(t), t), stmt.position) value += 1 } } def registerStruct(stmt: StructDefinitionStatement): Unit = { stmt.fields.foreach{ f => - if (Environment.invalidFieldNames.contains(f._2)) { - log.error(s"Invalid field name: `${f._2}`", stmt.position) + if (Environment.invalidFieldNames.contains(f.fieldName)) { + log.error(s"Invalid field name: `${f.fieldName}`", stmt.position) } } addThing(StructType(stmt.name, stmt.fields), stmt.position) @@ -908,8 +907,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def registerUnion(stmt: UnionDefinitionStatement): Unit = { stmt.fields.foreach{ f => - if (Environment.invalidFieldNames.contains(f._2)) { - log.error(s"Invalid field name: `${f._2}`", stmt.position) + if (Environment.invalidFieldNames.contains(f.fieldName)) { + log.error(s"Invalid field name: `${f.fieldName}`", stmt.position) } } addThing(UnionType(stmt.name, stmt.fields), stmt.position) @@ -924,7 +923,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa else { val newPath = path + name var sum = 0 - for( (fieldType, _) <- s.fields) { + for( FieldDesc(fieldType, _) <- s.fields) { val fieldSize = getTypeSize(fieldType, newPath) if (fieldSize < 0) return -1 sum += fieldSize @@ -935,7 +934,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } val b = get[Type]("byte") var offset = 0 - for( (fieldType, fieldName) <- s.fields) { + for( FieldDesc(fieldType, fieldName) <- s.fields) { addThing(ConstantThing(s"$name.$fieldName.offset", NumericConstant(offset, 1), b), None) offset += getTypeSize(fieldType, newPath) } @@ -946,7 +945,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa else { val newPath = path + name var max = 0 - for( (fieldType, _) <- s.fields) { + for( FieldDesc(fieldType, _) <- s.fields) { val fieldSize = getTypeSize(fieldType, newPath) if (fieldSize < 0) return -1 max = max max fieldSize @@ -956,7 +955,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error(s"Union `$name` is larger than 255 bytes") } val b = get[Type]("byte") - for ((fieldType, fieldName) <- s.fields) { + for (FieldDesc(fieldType, fieldName) <- s.fields) { addThing(ConstantThing(s"$name.$fieldName.offset", NumericConstant(0, 1), b), None) } max @@ -1358,7 +1357,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa List.fill(tt.size)(LiteralExpression(0, 1)) } else { tt.fields.zip(fieldValues).flatMap { - case ((fieldTypeName, _), expr) => extractStructArrayContents(expr, Some(get[Type](fieldTypeName))) + case (FieldDesc(fieldTypeName, _), expr) => extractStructArrayContents(expr, Some(get[Type](fieldTypeName))) } } case _ => @@ -1392,7 +1391,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa List.fill(tt.size)(LiteralExpression(0, 1)) } else { tt.fields.zip(fieldValues).flatMap { - case ((fieldTypeName, _), expr) => extractStructArrayContents(expr, Some(get[Type](fieldTypeName))) + case (FieldDesc(fieldTypeName, _), expr) => extractStructArrayContents(expr, Some(get[Type](fieldTypeName))) } } case _ => @@ -1649,7 +1648,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (stmt.register) log.error(s"`$name` is a constant and cannot be in a register", position) if (stmt.address.isDefined) log.error(s"`$name` is a constant and cannot have an address", position) if (stmt.initialValue.isEmpty) log.error(s"`$name` is a constant and requires a value", position) - val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(errorConstant(s"`$name` has a non-constant value", position)) + val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(errorConstant(s"`$name` has a non-constant value", position)).fitInto(typ) if (constantValue.requiredSize > typ.size) log.error(s"`$name` is has an invalid value: not in the range of `$typ`", position) addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position) for((suffix, offset, t) <- getSubvariables(typ)) { @@ -1839,7 +1838,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case s: StructType => val builder = new ListBuffer[(String, Int, VariableType)] var offset = 0 - for((typeName, fieldName) <- s.fields) { + for(FieldDesc(typeName, fieldName) <- s.fields) { val typ = get[VariableType](typeName) val suffix = "." + fieldName builder += ((suffix, offset, typ)) @@ -1851,7 +1850,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa builder.toList case s: UnionType => val builder = new ListBuffer[(String, Int, VariableType)] - for((typeName, fieldName) <- s.fields) { + for(FieldDesc(typeName, fieldName) <- s.fields) { val typ = get[VariableType](typeName) val suffix = "." + fieldName builder += ((suffix, 0, typ)) @@ -1937,11 +1936,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa things.values.foreach { case st@StructType(_, fields) => st.mutableFieldsWithTypes = fields.map { - case (tn, name) => get[Type](tn) -> name + case FieldDesc(tn, name) => get[Type](tn) -> name } case ut@UnionType(_, fields) => ut.mutableFieldsWithTypes = fields.map { - case (tn, name) => get[Type](tn) -> name + case FieldDesc(tn, name) => get[Type](tn) -> name } case _ => () } diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 110a408a..b0d0d362 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -112,14 +112,14 @@ case class EnumType(name: String, count: Option[Int]) extends VariableType { sealed trait CompoundVariableType extends VariableType -case class StructType(name: String, fields: List[(String, String)]) extends CompoundVariableType { +case class StructType(name: String, fields: List[FieldDesc]) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 var mutableFieldsWithTypes: List[(Type, String)] = Nil override def isSigned: Boolean = false } -case class UnionType(name: String, fields: List[(String, String)]) extends CompoundVariableType { +case class UnionType(name: String, fields: List[FieldDesc]) extends CompoundVariableType { override def size: Int = mutableSize var mutableSize: Int = -1 var mutableFieldsWithTypes: List[(Type, String)] = Nil diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 38260d1b..1cd11c7d 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -10,6 +10,8 @@ import millfork.output.MemoryAlignment case class Position(moduleName: String, line: Int, column: Int, cursor: Int) +case class FieldDesc(typeName:String, fieldName: String) + sealed trait Node { var position: Option[Position] = None } @@ -455,11 +457,11 @@ case class EnumDefinitionStatement(name: String, variants: List[(String, Option[ override def getAllExpressions: List[Expression] = variants.flatMap(_._2) } -case class StructDefinitionStatement(name: String, fields: List[(String, String)]) extends DeclarationStatement { +case class StructDefinitionStatement(name: String, fields: List[FieldDesc]) extends DeclarationStatement { override def getAllExpressions: List[Expression] = Nil } -case class UnionDefinitionStatement(name: String, fields: List[(String, String)]) extends DeclarationStatement { +case class UnionDefinitionStatement(name: String, fields: List[FieldDesc]) extends DeclarationStatement { override def getAllExpressions: List[Expression] = Nil } diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index b60f9fb0..26bd8465 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -538,12 +538,12 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri variants <- enumVariants ~/ Pass } yield Seq(EnumDefinitionStatement(name, variants).pos(p)) - val compoundTypeField: P[(String, String)] = for { + val compoundTypeField: P[FieldDesc] = for { typ <- identifier ~/ HWS name <- identifier ~ HWS - } yield typ -> name + } yield FieldDesc(typ, name) - val compoundTypeFields: P[List[(String, String)]] = + val compoundTypeFields: P[List[FieldDesc]] = ("{" ~/ AWS ~ compoundTypeField.rep(sep = NoCut(EOLOrComma) ~ !"}" ~/ Pass) ~/ AWS ~/ "}" ~/ Pass).map(_.toList) val structDefinition: P[Seq[StructDefinitionStatement]] = for { diff --git a/src/test/scala/millfork/test/ConstantSuite.scala b/src/test/scala/millfork/test/ConstantSuite.scala index 50b1b958..e9c9f310 100644 --- a/src/test/scala/millfork/test/ConstantSuite.scala +++ b/src/test/scala/millfork/test/ConstantSuite.scala @@ -1,6 +1,7 @@ package millfork.test import millfork.Cpu +import millfork.env.{BasicPlainType, DerivedPlainType, NumericConstant} import millfork.test.emu.EmuUnoptimizedCrossPlatformRun import org.scalatest.{FunSuite, Matchers} @@ -21,4 +22,15 @@ class ConstantSuite extends FunSuite with Matchers { | } """.stripMargin){m => } } + + test("Constants should be negative when needed") { + def signed(size: Int) = DerivedPlainType("", BasicPlainType("", size), isSigned = true, isPointy = false) + NumericConstant(0xff, 1).isProvablyNegative(signed(1)) should be(true) + NumericConstant(0x7f, 1).isProvablyNegative(signed(1)) should be(false) + NumericConstant(0xff, 2).isProvablyNegative(signed(2)) should be(false) + NumericConstant(0xff0f, 2).isProvablyNegative(signed(1)) should be(false) + NumericConstant(-0x4000, 8).isProvablyNegative(signed(1)) should be(false) + NumericConstant(0x7f, 2).isProvablyNegative(signed(2)) should be(false) + NumericConstant(-1, 8).isProvablyNegative(signed(8)) should be(true) + } } \ No newline at end of file diff --git a/src/test/scala/millfork/test/WordMathSuite.scala b/src/test/scala/millfork/test/WordMathSuite.scala index eee093b0..71dc22f1 100644 --- a/src/test/scala/millfork/test/WordMathSuite.scala +++ b/src/test/scala/millfork/test/WordMathSuite.scala @@ -700,4 +700,37 @@ class WordMathSuite extends FunSuite with Matchers with AppendedClues { m.readWord(0xc002) should equal((x % y) & 0xffff) withClue s"= $x %% $y (c002)" } } + + test("Sign extension in subtraction") { + for { + i <- Seq(5324, 6453, 1500) + j <- Seq(0, 1, -1, -3, -7, -128, 127) +// i <- Seq(5324) +// j <- Seq(-1) + } { + EmuUnoptimizedCrossPlatformRun(/*Cpu.Mos, */Cpu.Z80)( + s""" + | word output0 @$$c000 + | word output1 @$$c002 + | word output2 @$$c004 + | void main () { + | sbyte tmp + | output0 = $i + | output2 = $i + | tmp = $j + | memory_barrier() + | output1 = output0 - sbyte(${j&0xff}) + | memory_barrier() + | output0 -= sbyte(${j&0xff}) + | output2 -= tmp + | } + | noinline word id(word w) = w + """. + stripMargin){m => + m.readWord(0xc000) should equal((i - j) & 0xffff) withClue s"= $i - $j (c000)" + m.readWord(0xc002) should equal((i - j) & 0xffff) withClue s"= $i - $j (c002)" + m.readWord(0xc004) should equal((i - j) & 0xffff) withClue s"= $i - $j (c004)" + } + } + } }