diff --git a/src/main/scala/millfork/compiler/mos/ComparisonType.scala b/src/main/scala/millfork/compiler/ComparisonType.scala similarity index 97% rename from src/main/scala/millfork/compiler/mos/ComparisonType.scala rename to src/main/scala/millfork/compiler/ComparisonType.scala index ed894d9c..c2812177 100644 --- a/src/main/scala/millfork/compiler/mos/ComparisonType.scala +++ b/src/main/scala/millfork/compiler/ComparisonType.scala @@ -1,4 +1,4 @@ -package millfork.compiler.mos +package millfork.compiler /** * @author Karol Stasiak diff --git a/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala new file mode 100644 index 00000000..ad509971 --- /dev/null +++ b/src/main/scala/millfork/compiler/z80/Z80Comparisons.scala @@ -0,0 +1,38 @@ +package millfork.compiler.z80 + +import millfork.assembly.z80._ +import millfork.compiler._ +import millfork.node.{Expression, ZRegister} + +/** + * @author Karol Stasiak + */ +object Z80Comparisons { + + import ComparisonType._ + + def compile8BitComparison(ctx: CompilationContext, compType: ComparisonType.Value, l: Expression, r: Expression, branches: BranchSpec): List[ZLine] = { + compType match { + case GreaterUnsigned | LessOrEqualUnsigned | GreaterSigned | LessOrEqualSigned => + return compile8BitComparison(ctx, ComparisonType.flip(compType), r, l, branches) + case _ => () + } + val calculateFlags = + Z80ExpressionCompiler.compileToA(ctx, r) ++ + List(ZLine.ld8(ZRegister.E, ZRegister.A)) ++ + Z80ExpressionCompiler.stashDEIfChanged(Z80ExpressionCompiler.compileToA(ctx, l)) ++ + List(ZLine.register(ZOpcode.CP, ZRegister.E)) + val jump = (compType, branches) match { + case (Equal, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (Equal, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.Z)) + case (NotEqual, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.Z)) + case (LessUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case (LessUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfTrue(label)) => ZLine.jump(label, IfFlagClear(ZFlag.C)) + case (GreaterOrEqualUnsigned, BranchIfFalse(label)) => ZLine.jump(label, IfFlagSet(ZFlag.C)) + case _ => ??? + } + calculateFlags :+ jump + } +} diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index a536f4cf..a10e9e95 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -216,22 +216,52 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { ??? case "<" => val (size, signed) = assertComparison(ctx, params) - ??? + compileTransitiveRelation(ctx, "<", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + case _ => ??? + } + } case ">=" => val (size, signed) = assertComparison(ctx, params) - ??? + compileTransitiveRelation(ctx, ">=", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + case _ => ??? + } + } case ">" => val (size, signed) = assertComparison(ctx, params) - ??? + compileTransitiveRelation(ctx, ">", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + case _ => ??? + } + } case "<=" => val (size, signed) = assertComparison(ctx, params) - ??? + compileTransitiveRelation(ctx, "<=", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + case _ => ??? + } + } case "==" => val size = params.map(p => getExpressionType(ctx, p).size).max - ??? + compileTransitiveRelation(ctx, "==", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, ComparisonType.Equal, l, r, branches) + case _ => ??? + } + } case "!=" => val (l, r, size) = assertBinary(ctx, params) - ??? + compileTransitiveRelation(ctx, "!=", params, target, branches) { (l, r) => + size match { + case 1 => Z80Comparisons.compile8BitComparison(ctx, ComparisonType.NotEqual, l, r, branches) + case _ => ??? + } + } case "+=" => val (l, r, size) = assertAssignmentLike(ctx, params) size match { @@ -495,4 +525,39 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { case SeparateBytesExpression(hi, lo) => ??? } } + + private def compileTransitiveRelation(ctx: CompilationContext, + operator: String, + params: List[Expression], + target: ZExpressionTarget.Value, + branches: BranchSpec)(binary: (Expression, Expression) => List[ZLine]): List[ZLine] = { + params match { + case List(l, r) => binary(l, r) + case List(_) | Nil => + ErrorReporting.fatal("") + case _ => + params.tail.init.foreach { e => + if (ctx.env.eval(e).isEmpty) e match { + case VariableExpression(_) => + case LiteralExpression(_, _) => + case IndexedExpression(_, VariableExpression(_)) => + case IndexedExpression(_, LiteralExpression(_, _)) => + case IndexedExpression(_, SumExpression(List( + (_, LiteralExpression(_, _)), + (false, VariableExpression(_)) + ), false)) => + case IndexedExpression(_, SumExpression(List( + (false, VariableExpression(_)), + (_, LiteralExpression(_, _)) + ), false)) => + case _ => + ErrorReporting.warn("A complex expression may be evaluated multiple times", ctx.options, e.position) + } + } + val conjunction = params.init.zip(params.tail).map { + case (l, r) => FunctionCallExpression(operator, List(l, r)) + }.reduceLeft((a, b) => FunctionCallExpression("&&", List(a, b))) + compile(ctx, conjunction, target, branches) + } + } } diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index 852d325b..1507c3a4 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -1,6 +1,7 @@ package millfork.test -import millfork.test.emu.{EmuBenchmarkRun, EmuSuperOptimizedRun, EmuUltraBenchmarkRun} +import millfork.CpuFamily +import millfork.test.emu.{EmuBenchmarkRun, EmuCrossPlatformBenchmarkRun, EmuSuperOptimizedRun, EmuUltraBenchmarkRun} import org.scalatest.{FunSuite, Matchers} /** @@ -9,7 +10,7 @@ import org.scalatest.{FunSuite, Matchers} class ComparisonSuite extends FunSuite with Matchers { test("Equality and inequality") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( """ | byte output @$c000 | void main () { @@ -27,7 +28,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Less") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( """ | byte output @$c000 | void main () { @@ -40,7 +41,7 @@ class ComparisonSuite extends FunSuite with Matchers { } test("Compare to zero") { - EmuBenchmarkRun( + EmuCrossPlatformBenchmarkRun(CpuFamily.M6502, CpuFamily.I80)( """ | byte output @$c000 | void main () {