diff --git a/CHANGELOG.md b/CHANGELOG.md index d864bebe..c6e19a6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ * Fixed a bug with variable overlapping (#11). +* 8080: Fixed and optimized 16-bit comparisons. + * 8080: Optimized some library functions. * Optimized certain byte comparisons. diff --git a/docs/lang/operators.md b/docs/lang/operators.md index 4bbe4946..4cf2d713 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -161,7 +161,8 @@ Note you cannot mix those operators, so `a <= b < c` is not valid. * `>`, `<`, `<=`, `>=`: inequality `byte > byte` -`simple word > simple word` +`simple word > word` +`word > simple word` `simple long > simple long` Currently, `>`, `<`, `<=`, `>=` operators perform signed comparison diff --git a/src/main/scala/millfork/compiler/mos/BuiltIns.scala b/src/main/scala/millfork/compiler/mos/BuiltIns.scala index 09fb57ee..75d85d4c 100644 --- a/src/main/scala/millfork/compiler/mos/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/mos/BuiltIns.scala @@ -724,7 +724,7 @@ object BuiltIns { val rc = MosExpressionCompiler.compileToAX(ctx, rhs) (lc, rc, effectiveComparisonType) match { case ( - List(lcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), lch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage, _)), + List(lcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), lch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), _, ComparisonType.NotEqual ) => @@ -735,7 +735,7 @@ object BuiltIns { AssemblyLine.relative(BNE, Label(x))) case ( _, - List(rcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), rch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage, _)), + List(rcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), rch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), ComparisonType.NotEqual ) => return lc ++ List( @@ -744,7 +744,7 @@ object BuiltIns { rch.copy(opcode = CPX), AssemblyLine.relative(BNE, Label(x))) case ( - List(lcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), lch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage, _)), + List(lcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), lch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), _, ComparisonType.Equal ) => @@ -757,7 +757,7 @@ object BuiltIns { AssemblyLine.label(skip)) case ( _, - List(rcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), rch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage, _)), + List(rcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), rch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), ComparisonType.Equal ) => val skip = ctx.nextLabel("cp") @@ -767,6 +767,42 @@ object BuiltIns { rch.copy(opcode = CPX), AssemblyLine.relative(BEQ, Label(x)), AssemblyLine.label(skip)) + case ( + List(lcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), lch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), + List(rcl@AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), rch@AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), + _ + ) => + (Nil, + List(lch.copy(opcode = CMP)), + List(lcl.copy(opcode = CMP)), + List(rch.copy(opcode = CMP)), + List(rcl.copy(opcode = CMP))) + case ( + _, + List(AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), + _ + ) => + if (ctx.options.zpRegisterSize < 2) { + ctx.log.error("Too complex expressions in comparison", lhs.position.orElse(rhs.position)) + (Nil, Nil, Nil, Nil, Nil) + } else { + val reg = ctx.env.get[ThingInMemory]("__reg.loword") + return lc ++ List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) ++ compileWordComparison( + ctx, effectiveComparisonType, VariableExpression("__reg.loword").pos(lhs.position), rhs, BranchIfTrue(x)) + } + case ( + List(AssemblyLine0(LDA, Absolute | Immediate | ZeroPage | LongAbsolute, _), AssemblyLine0(LDX, Absolute | Immediate | ZeroPage | LongAbsolute, _)), + _, + _ + ) => + if (ctx.options.zpRegisterSize < 2) { + ctx.log.error("Too complex expressions in comparison", lhs.position.orElse(rhs.position)) + (Nil, Nil, Nil, Nil, Nil) + } else { + val reg = ctx.env.get[ThingInMemory]("__reg.loword") + return rc ++ List(AssemblyLine.zeropage(STA, reg), AssemblyLine.zeropage(STX, reg, 1)) ++ compileWordComparison( + ctx, effectiveComparisonType, lhs, VariableExpression("__reg.loword").pos(rhs.position), BranchIfTrue(x)) + } case _ => // TODO comparing expressions ctx.log.error("Too complex expressions in comparison", lhs.position.orElse(rhs.position)) diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala index f55d4548..02682e4e 100644 --- a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -192,14 +192,21 @@ object Z80Comparisons { case _ => } - val (calculated, useBC) = if (calculateLeft.exists(Z80ExpressionCompiler.changesBC)) { - if (calculateLeft.exists(Z80ExpressionCompiler.changesDE)) { - calculateRight ++ List(ZLine.register(ZOpcode.PUSH, ZRegister.HL)) ++ Z80ExpressionCompiler.fixTsx(ctx, calculateLeft) ++ List(ZLine.register(ZOpcode.POP, ZRegister.BC)) -> false - } else { - calculateRight ++ List(ZLine.ld8(ZRegister.D, ZRegister.H), ZLine.ld8(ZRegister.E, ZRegister.L)) ++ calculateLeft -> false - } - } else { - calculateRight ++ List(ZLine.ld8(ZRegister.B, ZRegister.H), ZLine.ld8(ZRegister.C, ZRegister.L)) ++ calculateLeft -> true + val (calculated, useBC) = calculateRight match { + case List(ZLine0(LD_16, TwoRegisters(ZRegister.HL, ZRegister.IMM_16), c)) => + (calculateLeft :+ ZLine.ldImm16(ZRegister.BC, c)) -> true + case List(ZLine0(LD_16, TwoRegisters(ZRegister.HL, ZRegister.MEM_ABS_16), c)) if ctx.options.flag(CompilationFlag.EmitZ80Opcodes) => + (calculateLeft :+ ZLine.ldAbs16(ZRegister.BC, c)) -> true + case _ => + if (calculateLeft.exists(Z80ExpressionCompiler.changesBC)) { + if (calculateLeft.exists(Z80ExpressionCompiler.changesDE)) { + calculateRight ++ List(ZLine.register(ZOpcode.PUSH, ZRegister.HL)) ++ Z80ExpressionCompiler.fixTsx(ctx, calculateLeft) ++ List(ZLine.register(ZOpcode.POP, ZRegister.BC)) -> true + } else { + calculateRight ++ List(ZLine.ld8(ZRegister.D, ZRegister.H), ZLine.ld8(ZRegister.E, ZRegister.L)) ++ calculateLeft -> false + } + } else { + calculateRight ++ List(ZLine.ld8(ZRegister.B, ZRegister.H), ZLine.ld8(ZRegister.C, ZRegister.L)) ++ calculateLeft -> true + } } val (effectiveCompType, label) = branches match { case BranchIfFalse(la) => ComparisonType.negate(compType) -> la diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index 12a46f1b..8598a690 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -581,4 +581,52 @@ class ComparisonSuite extends FunSuite with Matchers { } } + test("Complex word comparisons") { + val code = + """ + | byte output @$c000 + | struct st { word x } + | st obj + | noinline word f() = 400 + | void main() { + | pointer.st p + | p = obj.pointer + | p->x = 400 + | output = 0 + | + | if p->x == 400 { output += 1 } // ↑ + | if p->x != 400 { output -= 1 } + | if p->x >= 400 { output += 1 } // ↑ + | if p->x <= 400 { output += 1 } // ↑ + | if p->x > 400 { output -= 1 } + | if p->x < 400 { output -= 1 } + | + | if 400 == p->x { output += 1 } // ↑ + | if 400 != p->x { output -= 1 } + | if 400 <= p->x { output += 1 } // ↑ + | if 400 >= p->x { output += 1 } // ↑ + | if 400 < p->x { output -= 1 } + | if 400 > p->x { output -= 1 } + | + | if f() == 400 { output += 1 } // ↑ + | if f() != 400 { output -= 1 } + | if f() >= 400 { output += 1 } // ↑ + | if f() <= 400 { output += 1 } // ↑ + | if f() > 400 { output -= 1 } + | if f() < 400 { output -= 1 } + | + | if 400 == f() { output += 1 } // ↑ + | if 400 != f() { output -= 1 } + | if 400 <= f() { output += 1 } // ↑ + | if 400 >= f() { output += 1 } // ↑ + | if 400 < f() { output -= 1 } + | if 400 > f() { output -= 1 } + | + | } + |""".stripMargin + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80)(code) { m => + m.readByte(0xc000) should equal(code.count(_ == '↑')) + } + } + }