From bf5639761b167a83edac10a18820a18779800de0 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Wed, 26 Dec 2018 01:01:43 +0100 Subject: [PATCH] Optimize constant comparisons --- .../compiler/AbstractExpressionCompiler.scala | 74 ++++++++++++++++--- .../scala/millfork/test/BooleanSuite.scala | 27 +++++++ 2 files changed, 91 insertions(+), 10 deletions(-) diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 63ec966e..c662b48b 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -196,6 +196,44 @@ object AbstractExpressionCompiler { def getExpressionType(env: Environment, log: Logger, expr: Expression): Type = { val b = env.get[Type]("byte") val bool = env.get[Type]("bool$") + val boolTrue = env.get[Type]("true$") + val boolFalse= env.get[Type]("false$") + def toType(x: Boolean): Type = if (x) boolTrue else boolFalse + + def toAllNumericConstants(exprs: List[Expression]): Option[List[Long]] = { + for { + maybeConstants <- Some(exprs.map(env.eval)) + if maybeConstants.forall(_.isDefined) + constants = maybeConstants.map(_.get) + if constants.forall(_.isInstanceOf[NumericConstant]) + numericConstants = constants.map(_.asInstanceOf[NumericConstant]) + maxSize = numericConstants.map(_.requiredSize).fold(1)(_ max _) + mask = (1L << (8 * maxSize)) - 1 + signMask = ~mask + signTest = if (signMask == 0) Long.MaxValue else signMask >> 1 + types = exprs.map(e => getExpressionType(env, log, e)) + if types.forall(_.size.<(8)) + signednesses = types.flatMap { + case d: DerivedPlainType => Some(d.isSigned) + case _ => None + }.toSet + if signednesses.size < 2 + preserveSigns = signednesses.contains(true) + values = numericConstants.map(n => if (preserveSigns && n.value.&(signTest) != 0) n.value.&(mask).|(signMask) else n.value.&(mask)) +// _ = println(s"$constants $types $signednesses $preserveSigns $mask $signMask $signTest $values") + } yield values + } + def toAllBooleanConstants(exprs: List[Expression]): Option[List[Boolean]] = { + for { + types <- Some(exprs.map(e => getExpressionType(env, log, e))) + if types.forall(_.isInstanceOf[ConstantBooleanType]) + bools = types.map(_.asInstanceOf[ConstantBooleanType].value) + if bools.nonEmpty + } yield bools + } + + def monotonous(values: List[Long], pred: (Long, Long) => Boolean): Boolean = values.init.zip(values.tail).forall(pred.tupled) + val v = env.get[Type]("void") val w = env.get[Type]("word") expr match { @@ -225,7 +263,11 @@ object AbstractExpressionCompiler { case _ => log.error("Adding values bigger than words", expr.position); w } case FunctionCallExpression("nonet", params) => w - case FunctionCallExpression("not", params) => bool + case FunctionCallExpression("not", params) => + toAllBooleanConstants(params) match { + case Some(List(x)) => toType(!x) + case _ => bool + } case FunctionCallExpression("hi", params) => b case FunctionCallExpression("lo", params) => b case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match { @@ -246,15 +288,27 @@ object AbstractExpressionCompiler { case FunctionCallExpression("<<'", params) => b case FunctionCallExpression(">>'", params) => b case FunctionCallExpression(">>>>", params) => b - case FunctionCallExpression("&&", params) => bool - case FunctionCallExpression("||", params) => bool - case FunctionCallExpression("^^", params) => bool - case FunctionCallExpression("==", params) => bool - case FunctionCallExpression("!=", params) => bool - case FunctionCallExpression("<", params) => bool - case FunctionCallExpression(">", params) => bool - case FunctionCallExpression("<=", params) => bool - case FunctionCallExpression(">=", params) => bool + case FunctionCallExpression("&&", params) => + toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ && _))) + case FunctionCallExpression("||", params) => + toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ || _))) + case FunctionCallExpression("^^", params) => + toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ != _))) + case FunctionCallExpression("==", params) => + toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ == _))) + case FunctionCallExpression("!=", params) => + toAllNumericConstants(params) match { + case Some(List(x, y)) => toType(x != y) + case _ => bool + } + case FunctionCallExpression("<", params) => + toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ < _))) + case FunctionCallExpression(">", params) => + toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ > _))) + case FunctionCallExpression("<=", params) => + toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ <= _))) + case FunctionCallExpression(">=", params) => + toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ >= _))) case FunctionCallExpression("+=", params) => v case FunctionCallExpression("-=", params) => v case FunctionCallExpression("*=", params) => v diff --git a/src/test/scala/millfork/test/BooleanSuite.scala b/src/test/scala/millfork/test/BooleanSuite.scala index 30edcc19..d58f9e82 100644 --- a/src/test/scala/millfork/test/BooleanSuite.scala +++ b/src/test/scala/millfork/test/BooleanSuite.scala @@ -62,4 +62,31 @@ class BooleanSuite extends FunSuite with Matchers { | } """.stripMargin)(_.readByte(0xc000) should equal(11)) } + + test("Constant conditions big suite") { + val code =""" + | byte output @$c000 + | const pointer outside = $bfff + | void main () { + | output = 1 + | if 1 == 1 { pass() } + | if 1 == 2 { fail() } + | if 1 != 1 { fail() } + | if 1 != 2 { pass() } + | if 1 < 2 { pass() } + | if 1 < 1 { fail() } + | if sbyte(1) < sbyte(255) { fail() } + | if sbyte(1) > sbyte(255) { pass() } + | if sbyte(1) > 0 { pass() } + | if sbyte(-1) > 0 { fail() } + | if 00001 < 00002 { pass() } + | if 00001 > 00002 { fail() } + | if 00001 != 00002 { pass() } + | if 00001 == 00002 { fail() } + | } + | inline void pass() { output += 1 } + | noinline void fail() { outside[0] = 0 } + """.stripMargin + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(code)(_.readByte(0xc000) should equal(code.sliding(4).count(_ == "pass"))) + } }