From 504cb56ee7e8e8853453cbd84129a37eb2ff9265 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Mon, 1 Jan 2018 22:37:23 +0100 Subject: [PATCH] Multiple arguments for some relative operators --- .../scala/millfork/compiler/MfCompiler.scala | 74 ++++++++++++------- .../scala/millfork/test/ComparisonSuite.scala | 32 ++++++++ 2 files changed, 81 insertions(+), 25 deletions(-) diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala index ffcc7401..bdf4923c 100644 --- a/src/main/scala/millfork/compiler/MfCompiler.scala +++ b/src/main/scala/millfork/compiler/MfCompiler.scala @@ -405,15 +405,12 @@ object MlCompiler { } } - def assertComparison(ctx: CompilationContext, params: List[Expression]): (Expression, Expression, Int, Boolean) = { - if (params.length != 2) { - ErrorReporting.fatal("sfgdgfsd", None) - } + def assertComparison(ctx: CompilationContext, params: List[Expression]): (Int, Boolean) = { (params.head, params(1)) match { case (l: Expression, r: Expression) => val lt = getExpressionType(ctx, l) val rt = getExpressionType(ctx, r) - (l, r, lt.size max rt.size, lt.isSigned || rt.isSigned) + (lt.size max rt.size, lt.isSigned || rt.isSigned) } } @@ -902,37 +899,47 @@ object MlCompiler { DecimalBuiltIns.compileByteShiftRight(ctx, l, r, rotate = false) case "<" => // TODO: signed - val (l, r, size, signed) = assertComparison(ctx, params) - size match { - case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + val (size, signed) = assertComparison(ctx, params) + compileTransitiveRelation(ctx, "<", params, exprTypeAndVariable, branches) { (l, r) => + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessSigned else ComparisonType.LessUnsigned, l, r, branches) + } } case ">=" => // TODO: signed - val (l, r, size, signed) = assertComparison(ctx, params) - size match { - case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + val (size, signed) = assertComparison(ctx, params) + compileTransitiveRelation(ctx, ">=", params, exprTypeAndVariable, branches) { (l, r) => + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterOrEqualSigned else ComparisonType.GreaterOrEqualUnsigned, l, r, branches) + } } case ">" => // TODO: signed - val (l, r, size, signed) = assertComparison(ctx, params) - size match { - case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + val (size, signed) = assertComparison(ctx, params) + compileTransitiveRelation(ctx, ">", params, exprTypeAndVariable, branches) { (l, r) => + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.GreaterSigned else ComparisonType.GreaterUnsigned, l, r, branches) + } } case "<=" => // TODO: signed - val (l, r, size, signed) = assertComparison(ctx, params) - size match { - case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + val (size, signed) = assertComparison(ctx, params) + compileTransitiveRelation(ctx, "<=", params, exprTypeAndVariable, branches) { (l, r) => + size match { + case 1 => BuiltIns.compileByteComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, if (signed) ComparisonType.LessOrEqualSigned else ComparisonType.LessOrEqualUnsigned, l, r, branches) + } } case "==" => - val (l, r, size) = assertBinary(ctx, params) - size match { - case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches) - case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches) + val size = params.map(p => getExpressionType(ctx, p).size).max + compileTransitiveRelation(ctx, "==", params, exprTypeAndVariable, branches) { (l, r) => + size match { + case 1 => BuiltIns.compileByteComparison(ctx, ComparisonType.Equal, l, r, branches) + case 2 => BuiltIns.compileWordComparison(ctx, ComparisonType.Equal, l, r, branches) + } } case "!=" => val (l, r, size) = assertBinary(ctx, params) @@ -1138,6 +1145,23 @@ object MlCompiler { } } + private def compileTransitiveRelation(ctx: CompilationContext, + operator: String, + params: List[Expression], + exprTypeAndVariable: Option[(Type, Variable)], + branches: BranchSpec)(binary: (Expression, Expression) => List[AssemblyLine]): List[AssemblyLine] ={ + params match { + case List(l, r) => binary(l, r) + case List(_) | Nil => + ErrorReporting.fatal("") + case _ => + val conjunction = params.init.zip(params.tail).map { + case (l, r) => FunctionCallExpression(operator, List(l, r)) + }.reduceLeft((a,b) => FunctionCallExpression("&&", List(a, b))) + compile(ctx, conjunction, exprTypeAndVariable, branches) + } + } + def expressionStorageFromAX(ctx: CompilationContext, exprTypeAndVariable: Option[(Type, Variable)], position: Option[Position]): List[AssemblyLine] = { exprTypeAndVariable.fold(noop) { case (VoidType, _) => ??? diff --git a/src/test/scala/millfork/test/ComparisonSuite.scala b/src/test/scala/millfork/test/ComparisonSuite.scala index c2651f07..43f5506c 100644 --- a/src/test/scala/millfork/test/ComparisonSuite.scala +++ b/src/test/scala/millfork/test/ComparisonSuite.scala @@ -261,4 +261,36 @@ class ComparisonSuite extends FunSuite with Matchers { """.stripMargin EmuBenchmarkRun(src)(_.readWord(0xc000) should equal(src.count(_ == '+'))) } + + test("Multiple params for equality") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 5 + | if (output == 5 == 5) { + | output += 1 + | } + | if (output == 5 == 6) { + | output += 78 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(6)) + } + + test("Multiple params for inequality") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | output = 5 + | if 2 < 3 < 4 { + | output += 1 + | } + | if 2 < 3 < 2 { + | output += 78 + | } + | } + """.stripMargin)(_.readWord(0xc000) should equal(6)) + } }