diff --git a/CHANGELOG.md b/CHANGELOG.md index f9c67897..542e9528 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ * Preliminary Atari 2600, BBC Micro and LUnix support. -* Added array initialization syntax with `for` (not yet finalized). +* Added array initialization syntax with `for`. * Added multiple new text codecs. @@ -20,6 +20,8 @@ * Fixed broken `downto` loops. +* Fixed broken comparisons between variables of different sizes. + * Fixed several other bugs. * Other improvements. diff --git a/src/main/scala/millfork/assembly/mos/AssemblyLine.scala b/src/main/scala/millfork/assembly/mos/AssemblyLine.scala index 3b29c6d4..33339e30 100644 --- a/src/main/scala/millfork/assembly/mos/AssemblyLine.scala +++ b/src/main/scala/millfork/assembly/mos/AssemblyLine.scala @@ -369,10 +369,10 @@ object AssemblyLine { private val opcodesForNopVariableOperation = Set(STA, SAX, STX, STY, STZ) private val opcodesForZeroedVariableOperation = Set(ADC, EOR, ORA, AND, SBC, CMP, CPX, CPY) - private val opcodesForZeroedOrSignExtendedVariableOperation = Set(LDA, LDX, LDY) + private val opcodesForZeroedOrSignExtendedVariableOperation = Set(LDA, LDX, LDY, LDZ) def variable(ctx: CompilationContext, opcode: Opcode.Value, variable: Variable, offset: Int = 0): List[AssemblyLine] = - if (offset > variable.typ.size) { + if (offset >= variable.typ.size) { if (opcodesForNopVariableOperation(opcode)) { Nil } else if (opcodesForZeroedVariableOperation(opcode)) { @@ -381,11 +381,16 @@ object AssemblyLine { } else if (opcodesForZeroedOrSignExtendedVariableOperation(opcode)) { if (variable.typ.isSigned) { val label = MosCompiler.nextLabel("sx") - AssemblyLine.variable(ctx, opcode, variable, variable.typ.size - 1) ++ List( + AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) ++ List( AssemblyLine.immediate(ORA, 0x7f), AssemblyLine.relative(BMI, label), AssemblyLine.immediate(LDA, 0), - AssemblyLine.label(label)) + AssemblyLine.label(label)) ++ (opcode match { + case LDA => Nil + case LDX | LAX => List(AssemblyLine.implied(TAX)) + case LDY => List(AssemblyLine.implied(TAY)) + case LDZ => List(AssemblyLine.implied(TAZ)) + }) } else { List(AssemblyLine.immediate(opcode, 0)) } diff --git a/src/main/scala/millfork/assembly/mos/opt/AlwaysGoodOptimizations.scala b/src/main/scala/millfork/assembly/mos/opt/AlwaysGoodOptimizations.scala index 537db6e2..3e05ca27 100644 --- a/src/main/scala/millfork/assembly/mos/opt/AlwaysGoodOptimizations.scala +++ b/src/main/scala/millfork/assembly/mos/opt/AlwaysGoodOptimizations.scala @@ -1791,6 +1791,62 @@ object AlwaysGoodOptimizations { (Elidable & HasSourceOfNZ(State.X) & HasOpcode(CPX) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init), (Elidable & HasSourceOfNZ(State.Y) & HasOpcode(CPY) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init), (Elidable & HasSourceOfNZ(State.IZ) & HasOpcode(CPZ) & HasImmediate(0) & DoesntMatterWhatItDoesWith(State.C)) ~~> (_.init), + + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE)) + }, + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code => + List(code.head.copy(opcode = LDX), code(1).copy(opcode = BNE)) + }, + (Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code => + List(code.head.copy(opcode = LDX), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code => + List(code.head.copy(opcode = LDY), code(1).copy(opcode = BNE)) + }, + (Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code => + List(code.head.copy(opcode = LDY), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code => + List(code.head.copy(opcode = LDX), code(1).copy(opcode = BNE)) + }, + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.X)) ~~> {code => + List(code.head.copy(opcode = LDX), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code => + List(code.head.copy(opcode = LDY), code(1).copy(opcode = BNE)) + }, + (Elidable & HasA(0) & HasOpcode(CMP)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.Y)) ~~> {code => + List(code.head.copy(opcode = LDY), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE)) + }, + (Elidable & HasX(0) & HasOpcode(CPX)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BNE)) + }, + (Elidable & HasY(0) & HasOpcode(CPY)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.A)) ~~> {code => + List(code.head.copy(opcode = LDA), code(1).copy(opcode = BEQ)) + }, + + (Elidable & HasZ(0) & HasOpcode(CPZ)) ~ (HasOpcode(BCC) & DoesntMatterWhatItDoesWith(State.C, State.N, State.IZ)) ~~> {code => + List(code.head.copy(opcode = LDZ), code(1).copy(opcode = BNE)) + }, + (Elidable & HasZ(0) & HasOpcode(CPZ)) ~ (HasOpcode(BCS) & DoesntMatterWhatItDoesWith(State.C, State.N, State.IZ)) ~~> {code => + List(code.head.copy(opcode = LDZ), code(1).copy(opcode = BEQ)) + }, ) private def remapZ2N(line: AssemblyLine) = line.opcode match { diff --git a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala index 64da71e5..8febbc46 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementCompiler.scala @@ -107,7 +107,25 @@ abstract class AbstractStatementCompiler[T <: AbstractCode] { val increment = ExpressionStatement(FunctionCallExpression("+=", List(vex, one))) val decrement = ExpressionStatement(FunctionCallExpression("-=", List(vex, one))) val names = Set("", "for", f.variable) - (f.direction, ctx.env.eval(f.start), ctx.env.eval(f.end)) match { + + val startEvaluated = ctx.env.eval(f.start) + val endEvaluated = ctx.env.eval(f.end) + ctx.env.maybeGet[Variable](f.variable).foreach{ v=> + startEvaluated.foreach(value => if (!value.quickSimplify.fitsInto(v.typ)) { + ErrorReporting.error(s"Variable `${f.variable}` is too small to hold the initial value in the for loop", f.position) + }) + endEvaluated.foreach { value => + val max = f.direction match { + case ForDirection.To | ForDirection.ParallelTo | ForDirection.DownTo => value + case ForDirection.Until | ForDirection.ParallelUntil => value - 1 + case _ => Constant.Zero + } + if (!max.quickSimplify.fitsInto(v.typ)) { + ErrorReporting.error(s"Variable `${f.variable}` is too small to hold the final value in the for loop", f.position) + } + } + } + (f.direction, startEvaluated, endEvaluated) match { case (ForDirection.Until | ForDirection.ParallelUntil, Some(NumericConstant(s, ssize)), Some(NumericConstant(e, _))) if s == e - 1 => val end = nextLabel("of") diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 890a9223..1d9ecf08 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -32,6 +32,8 @@ object BuiltIns { def staTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == STA) x.copy(opcode = op) else x) + def cmpTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == CMP) x.copy(opcode = op) else x) + def ldTo(op: Opcode.Value, l: List[AssemblyLine]): List[AssemblyLine] = l.map(x => if (x.opcode == LDA || x.opcode == LDX || x.opcode == LDY) x.copy(opcode = op) else x) def simpleOperation(opcode: Opcode.Value, ctx: CompilationContext, source: Expression, indexChoice: IndexChoice.Value, preserveA: Boolean, commutative: Boolean, decimal: Boolean = false): List[AssemblyLine] = { @@ -463,17 +465,17 @@ object BuiltIns { return compileWordComparison(ctx, ComparisonType.flip(compType), rhs, lhs, branches) case (v: VariableExpression, None, _, Some(rc)) => val lva = env.get[VariableInMemory](v.name) - (AssemblyLine.variable(ctx, STA, lva, 1), - AssemblyLine.variable(ctx, STA, lva, 0), - List(AssemblyLine.immediate(STA, rc.hiByte.quickSimplify)), - List(AssemblyLine.immediate(STA, rc.loByte.quickSimplify))) + (AssemblyLine.variable(ctx, CMP, lva, 1), + AssemblyLine.variable(ctx, CMP, lva, 0), + List(AssemblyLine.immediate(CMP, rc.hiByte.quickSimplify)), + List(AssemblyLine.immediate(CMP, rc.loByte.quickSimplify))) case (lv: VariableExpression, None, rv: VariableExpression, None) => val lva = env.get[VariableInMemory](lv.name) val rva = env.get[VariableInMemory](rv.name) - (AssemblyLine.variable(ctx, STA, lva, 1), - AssemblyLine.variable(ctx, STA, lva, 0), - AssemblyLine.variable(ctx, STA, rva, 1), - AssemblyLine.variable(ctx, STA, rva, 0)) + (AssemblyLine.variable(ctx, CMP, lva, 1), + AssemblyLine.variable(ctx, CMP, lva, 0), + AssemblyLine.variable(ctx, CMP, rva, 1), + AssemblyLine.variable(ctx, CMP, rva, 0)) case _ => // TODO comparing expressions ErrorReporting.error("Too complex expressions in comparison", lhs.position) @@ -484,9 +486,9 @@ object BuiltIns { val compactEqualityComparison = if (ctx.options.flag(CompilationFlag.OptimizeForSpeed)) { None } else if (lType.size == 1 && !lType.isSigned) { - Some(staTo(LDA, ll) ++ staTo(EOR, rl) ++ staTo(ORA, rh)) + Some(cmpTo(LDA, ll) ++ cmpTo(EOR, rl) ++ cmpTo(ORA, rh)) } else if (rType.size == 1 && !rType.isSigned) { - Some(staTo(LDA, rl) ++ staTo(EOR, ll) ++ staTo(ORA, lh)) + Some(cmpTo(LDA, rl) ++ cmpTo(EOR, ll) ++ cmpTo(ORA, lh)) } else { None } @@ -496,11 +498,11 @@ object BuiltIns { case Some(code) => code :+ AssemblyLine.relative(BEQ, Label(x)) case None => val innerLabel = MosCompiler.nextLabel("cp") - staTo(LDA, ll) ++ - staTo(CMP, rl) ++ + cmpTo(LDA, ll) ++ + cmpTo(CMP, rl) ++ List(AssemblyLine.relative(BNE, innerLabel)) ++ - staTo(LDA, lh) ++ - staTo(CMP, rh) ++ + cmpTo(LDA, lh) ++ + cmpTo(CMP, rh) ++ List( AssemblyLine.relative(BEQ, Label(x)), AssemblyLine.label(innerLabel)) @@ -510,57 +512,57 @@ object BuiltIns { compactEqualityComparison match { case Some(code) => code :+ AssemblyLine.relative(BNE, Label(x)) case None => - staTo(LDA, ll) ++ - staTo(CMP, rl) ++ + cmpTo(LDA, ll) ++ + cmpTo(CMP, rl) ++ List(AssemblyLine.relative(BNE, Label(x))) ++ - staTo(LDA, lh) ++ - staTo(CMP, rh) ++ + cmpTo(LDA, lh) ++ + cmpTo(CMP, rh) ++ List(AssemblyLine.relative(BNE, Label(x))) } case ComparisonType.LessUnsigned => val innerLabel = MosCompiler.nextLabel("cp") - staTo(LDA, lh) ++ - staTo(CMP, rh) ++ + cmpTo(LDA, lh) ++ + cmpTo(CMP, rh) ++ List( AssemblyLine.relative(BCC, Label(x)), AssemblyLine.relative(BNE, innerLabel)) ++ - staTo(LDA, ll) ++ - staTo(CMP, rl) ++ + cmpTo(LDA, ll) ++ + cmpTo(CMP, rl) ++ List( AssemblyLine.relative(BCC, Label(x)), AssemblyLine.label(innerLabel)) case ComparisonType.LessOrEqualUnsigned => val innerLabel = MosCompiler.nextLabel("cp") - staTo(LDA, rh) ++ - staTo(CMP, lh) ++ + cmpTo(LDA, rh) ++ + cmpTo(CMP, lh) ++ List(AssemblyLine.relative(BCC, innerLabel), AssemblyLine.relative(BNE, x)) ++ - staTo(LDA, rl) ++ - staTo(CMP, ll) ++ + cmpTo(LDA, rl) ++ + cmpTo(CMP, ll) ++ List(AssemblyLine.relative(BCS, x), AssemblyLine.label(innerLabel)) case ComparisonType.GreaterUnsigned => val innerLabel = MosCompiler.nextLabel("cp") - staTo(LDA, rh) ++ - staTo(CMP, lh) ++ + cmpTo(LDA, rh) ++ + cmpTo(CMP, lh) ++ List(AssemblyLine.relative(BCC, Label(x)), AssemblyLine.relative(BNE, innerLabel)) ++ - staTo(LDA, rl) ++ - staTo(CMP, ll) ++ + cmpTo(LDA, rl) ++ + cmpTo(CMP, ll) ++ List(AssemblyLine.relative(BCC, Label(x)), AssemblyLine.label(innerLabel)) case ComparisonType.GreaterOrEqualUnsigned => val innerLabel = MosCompiler.nextLabel("cp") - staTo(LDA, lh) ++ - staTo(CMP, rh) ++ + cmpTo(LDA, lh) ++ + cmpTo(CMP, rh) ++ List(AssemblyLine.relative(BCC, innerLabel), AssemblyLine.relative(BNE, x)) ++ - staTo(LDA, ll) ++ - staTo(CMP, rl) ++ + cmpTo(LDA, ll) ++ + cmpTo(CMP, rl) ++ List(AssemblyLine.relative(BCS, x), AssemblyLine.label(innerLabel)) @@ -619,23 +621,23 @@ object BuiltIns { case ComparisonType.Equal => val innerLabel = MosCompiler.nextLabel("cp") val bytewise = l.zip(r).map{ - case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR) + case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(CMP, cmpR) } bytewise.init.flatMap(b => b :+ AssemblyLine.relative(BNE, innerLabel)) ++ bytewise.last ++List( AssemblyLine.relative(BEQ, Label(label)), AssemblyLine.label(innerLabel)) case ComparisonType.NotEqual => l.zip(r).flatMap { - case (staL, staR) => staTo(LDA, staL) ++ staTo(CMP, staR) :+ AssemblyLine.relative(BNE, label) + case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(CMP, cmpR) :+ AssemblyLine.relative(BNE, label) } case ComparisonType.LessUnsigned => val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{ - case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR) + case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(SBC, cmpR) } calculateCarry ++ List(AssemblyLine.relative(BCC, Label(label))) case ComparisonType.GreaterOrEqualUnsigned => val calculateCarry = AssemblyLine.implied(SEC) :: l.zip(r).flatMap{ - case (staL, staR) => staTo(LDA, staL) ++ staTo(SBC, staR) + case (cmpL, cmpR) => cmpTo(LDA, cmpL) ++ cmpTo(SBC, cmpR) } calculateCarry ++ List(AssemblyLine.relative(BCS, Label(label))) case ComparisonType.GreaterUnsigned | ComparisonType.LessOrEqualUnsigned => @@ -1180,27 +1182,27 @@ object BuiltIns { val env = ctx.env env.eval(expr) match { case Some(c) => - List.tabulate(size) { i => List(AssemblyLine.immediate(STA, c.subbyte(i))) } + List.tabulate(size) { i => List(AssemblyLine.immediate(CMP, c.subbyte(i))) } case None => expr match { case v: VariableExpression => val variable = env.get[Variable](v.name) List.tabulate(size) { i => if (i < variable.typ.size) { - AssemblyLine.variable(ctx, STA, variable, i) + AssemblyLine.variable(ctx, CMP, variable, i) } else if (variable.typ.isSigned) { val label = MosCompiler.nextLabel("sx") - AssemblyLine.variable(ctx, STA, variable, variable.typ.size - 1) ++ List( + AssemblyLine.variable(ctx, LDA, variable, variable.typ.size - 1) ++ List( AssemblyLine.immediate(ORA, 0x7F), AssemblyLine.relative(BMI, label), - AssemblyLine.immediate(STA, 0), + AssemblyLine.immediate(CMP, 0), AssemblyLine.label(label)) - } else List(AssemblyLine.immediate(STA, 0)) + } else List(AssemblyLine.immediate(CMP, 0)) } case expr@IndexedExpression(variable, index) => List.tabulate(size) { i => if (i == 0) MosExpressionCompiler.compileByteStorage(ctx, MosRegister.A, expr) - else List(AssemblyLine.immediate(STA, 0)) + else List(AssemblyLine.immediate(CMP, 0)) } case SeparateBytesExpression(h: LhsExpression, l: LhsExpression) => if (simplicity(ctx.env, h) < 'J' || simplicity(ctx.env, l) < 'J') { @@ -1210,7 +1212,7 @@ object BuiltIns { List.tabulate(size) { i => if (i == 0) getStorageForEachByte(ctx, l).head else if (i == 1) MosExpressionCompiler.preserveRegisterIfNeeded(ctx, MosRegister.A, getStorageForEachByte(ctx, h).head) - else List(AssemblyLine.immediate(STA, 0)) + else List(AssemblyLine.immediate(CMP, 0)) } case _ => ??? diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 7335bbb7..e7077cc0 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -73,6 +73,8 @@ sealed trait Constant { def quickSimplify: Constant = this def isRelatedTo(v: Thing): Boolean + + def fitsInto(typ: Type): Boolean = true // TODO } case class AssertByte(c: Constant) extends Constant { @@ -83,6 +85,8 @@ case class AssertByte(c: Constant) extends Constant { override def isRelatedTo(v: Thing): Boolean = c.isRelatedTo(v) override def quickSimplify: Constant = AssertByte(c.quickSimplify) + + override def fitsInto(typ: Type): Boolean = true } case class UnexpandedConstant(name: String, requiredSize: Int) extends Constant { @@ -116,6 +120,26 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant { override def toString: String = if (value > 9) value.formatted("$%X") else value.toString override def isRelatedTo(v: Thing): Boolean = false + + override def fitsInto(typ: Type): Boolean = { + if (typ.isSigned) { + typ.size match { + case 1 => value == value.toByte + case 2 => value == value.toShort + case 3 => value == ((value.toInt << 8) >> 8) + case 4 => value == value.toInt + case _ => true + } + } else { + typ.size match { + case 1 => value == (value & 0xff) + case 2 => value == (value & 0xffff) + case 3 => value == (value & 0xffffff) + case 4 => value == (value & 0xffffffffL) + case _ => true + } + } + } } case class MemoryAddressConstant(var thing: ThingInMemory) extends Constant { diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index efbcabdd..852d325b 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -340,4 +340,22 @@ class ComparisonSuite extends FunSuite with Matchers { """.stripMargin EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(src.count(_ == '+'))) } + + test("Mixed type comparison") { + val src = + """ + | byte output @$c000 + | byte x @$c002 + | byte y @$c003 + | void main () { + | word z + | output = 0 + | z = $100 + | x = 4 + | y = 1 + | if x < z { output += 1 } + | } + """.stripMargin + EmuBenchmarkRun(src)(_.readByte(0xc000) should equal(1)) + } }