From 34ef8b8de9db1b8ec8bb17f8b8f9d8720c2174cb Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Fri, 12 Nov 2021 02:44:20 +0100 Subject: [PATCH] Better evaluation of the if function in constants --- src/main/scala/millfork/env/Constant.scala | 26 ++++++++++++- src/main/scala/millfork/env/Environment.scala | 37 +++++++++++++------ .../millfork/output/AbstractAssembler.scala | 8 ++++ 3 files changed, 58 insertions(+), 13 deletions(-) diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index d46da642..b241934b 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -440,6 +440,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant { object MathOperator extends Enumeration { val Plus, Minus, Times, Shl, Shr, Shl9, Shr9, Plus9, DecimalPlus9, DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShl9, DecimalShr, + Equal, NotEqual, Less, LessEqual, Greater, GreaterEqual, Minimum, Maximum, Divide, Modulo, And, Or, Exor = Value @@ -599,7 +600,7 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.Exor => (lv ^ rv) & bitmask case MathOperator.Or => lv | rv case MathOperator.And => lv & rv & bitmask - case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv + case MathOperator.Divide if lv >= 0 && rv > 0 => lv / rv case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv case MathOperator.DecimalPlus if ls == 1 && rs == 1 => asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff @@ -692,6 +693,12 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case Exor => s"$plhs ^ $prhs" case Divide => s"$plhs / $prhs" case Modulo => s"$plhs %% $prhs" + case Equal => s"$plhs == $prhs" + case NotEqual => s"$plhs != $prhs" + case Greater => s"$plhs > $prhs" + case GreaterEqual => s"$plhs >= $prhs" + case Less=> s"$plhs < $prhs" + case LessEqual=> s"$plhs <= $prhs" } } @@ -783,3 +790,20 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co override def extractLabels: List[String] = lhs.extractLabels ++ rhs.extractLabels } + +case class IfConstant(cond: Constant, ifTrue: Constant, ifFalse: Constant) extends Constant { + + override def toIntelString: String = s"if(${cond.toIntelString},${ifTrue.toIntelString},${ifFalse.toIntelString})" + + override def toString: String = s"if(${cond.toString},${ifTrue.toString},${ifFalse.toString})" + + override def requiredSize: Int = ifTrue.requiredSize max ifFalse.requiredSize + + override def isRelatedTo(v: Thing): Boolean = cond.isRelatedTo(v) || ifTrue.isRelatedTo(v) || ifFalse.isRelatedTo(v) + + override def refersTo(name: String): Boolean = cond.refersTo(name) || ifTrue.refersTo(name) || ifFalse.refersTo(name) + + override def extractLabels: List[String] = cond.extractLabels ++ ifTrue.extractLabels ++ ifFalse.extractLabels + + override def rootThingName: String = "?" +} \ No newline at end of file diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 3e61d960..d30b4680 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -840,7 +840,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa case Some(c) => if (c.isProvablyGreaterOrEqualThan(1)) evalImpl(params(1), vv) else if (c.isProvablyZero) evalImpl(params(2), vv) - else None + else (evalImpl(params(1), vv), evalImpl(params(2), vv)) match { + case (Some(t), Some(f)) => Some(IfConstant(c, t, f)) + case _ => None + } case _ => None } } else { @@ -887,15 +890,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa constantOperation(MathOperator.Exor, params, vv) case "||" | "|" => constantOperation(MathOperator.Or, params, vv) - case ">" => evalComparisons(params, vv, _ > _) - case "<" => evalComparisons(params, vv, _ < _) - case ">=" => evalComparisons(params, vv, _ >= _) - case "<=" => evalComparisons(params, vv, _ <= _) - case "==" => evalComparisons(params, vv, _ == _) + case ">" => evalComparisons(params, vv, MathOperator.Greater, _ > _) + case "<" => evalComparisons(params, vv, MathOperator.Less,_ < _) + case ">=" => evalComparisons(params, vv, MathOperator.GreaterEqual,_ >= _) + case "<=" => evalComparisons(params, vv, MathOperator.LessEqual,_ <= _) + case "==" => evalComparisons(params, vv, MathOperator.Equal,_ == _) case "!=" => sequence(params.map(p => evalImpl(p, vv))) match { case Some(List(NumericConstant(n1, _), NumericConstant(n2, _))) => Some(if (n1 != n2) Constant.One else Constant.Zero) + case Some(List(c1, c2)) => + Some(CompoundConstant(MathOperator.NotEqual, c1, c2)) case _ => None } case _ => @@ -950,14 +955,22 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } - private def evalComparisons(params: List[Expression], vv: Option[Map[String, Constant]], cond: (Long, Long) => Boolean): Option[Constant] = { + private def evalComparisons(params: List[Expression], vv: Option[Map[String, Constant]], operator: MathOperator.Value, cond: (Long, Long) => Boolean): Option[Constant] = { if (params.size < 2) return None - val numbers = sequence(params.map{ e => - evalImpl(e, vv) match { - case Some(NumericConstant(n, _)) => Some(n) - case _ => None - } + val paramsEvaluated = params.map { e => + evalImpl(e, vv) + } + val numbers = sequence(paramsEvaluated.map { + case Some(NumericConstant(n, _)) => Some(n) + case _ => None }) + if (numbers.isEmpty) { + paramsEvaluated match { + case List(Some(c1), Some(c2)) => + return Some(CompoundConstant(operator, c1, c2)) + case _ => + } + } numbers.map { ns => if (ns.init.zip(ns.tail).forall(cond.tupled)) Constant.One else Constant.Zero } diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index 70665d5d..283470cc 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -111,6 +111,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program } c match { case NumericConstant(v, _) => v + case IfConstant(c, t, f) => + if (deepConstResolve(c) != 0) deepConstResolve(t) else deepConstResolve(f) case AssertByte(inner) => val value = deepConstResolve(inner) if (value.toByte == value) value else { @@ -184,6 +186,12 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program case MathOperator.DecimalShl => asDecimal(l, 1 << r, _ * _) case MathOperator.DecimalShl9 => asDecimal(l, 1 << r, _ * _) & 0x1ff case MathOperator.DecimalShr => asDecimal(l, 1 << r, _ / _) + case MathOperator.Equal => if (l == r) 1 else 0 + case MathOperator.NotEqual => if (l != r) 1 else 0 + case MathOperator.Less => if (l < r) 1 else 0 + case MathOperator.LessEqual => if (l <= r) 1 else 0 + case MathOperator.Greater => if (l > r) 1 else 0 + case MathOperator.GreaterEqual => if (l >= r) 1 else 0 case MathOperator.And => l & r case MathOperator.Exor => l ^ r case MathOperator.Or => l | r