From 4c0d184c47ee5d39bc27eea57c28ae8f8e0cdfd6 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Thu, 18 Jan 2018 22:38:17 +0100 Subject: [PATCH] Constant evaluation fixes --- .../scala/millfork/compiler/BuiltIns.scala | 38 ++++++++++++------- src/main/scala/millfork/env/Environment.scala | 34 +++++++++++------ 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/main/scala/millfork/compiler/BuiltIns.scala b/src/main/scala/millfork/compiler/BuiltIns.scala index e7969783..7ba63955 100644 --- a/src/main/scala/millfork/compiler/BuiltIns.scala +++ b/src/main/scala/millfork/compiler/BuiltIns.scala @@ -291,7 +291,8 @@ object BuiltIns { val env = ctx.env val b = env.get[Type]("byte") val firstParamCompiled = MlCompiler.compile(ctx, lhs, Some(b -> RegisterVariable(Register.A, b)), NoBranching) - env.eval(rhs) match { + val maybeConstant = env.eval(rhs) + maybeConstant match { case Some(NumericConstant(0, _)) => compType match { case ComparisonType.LessUnsigned => @@ -318,19 +319,30 @@ object BuiltIns { } case _ => } - val secondParamCompiledUnoptimized = simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false) - val secondParamCompiled = compType match { - case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned => - secondParamCompiledUnoptimized match { - case List(AssemblyLine(CMP, Immediate, NumericConstant(0, _), true)) => - if (OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) { - Nil - } else { - secondParamCompiledUnoptimized - } - case _ => secondParamCompiledUnoptimized + val secondParamCompiled = maybeConstant match { + case Some(x) => + compType match { + case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned => + if (x.quickSimplify.isLowestByteAlwaysEqual(0) && OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) Nil + else List(AssemblyLine.immediate(CMP, x)) + case _ => + List(AssemblyLine.immediate(CMP, x)) } - case _ => secondParamCompiledUnoptimized + case _ => compType match { + case ComparisonType.Equal | ComparisonType.NotEqual | ComparisonType.LessSigned | ComparisonType.GreaterOrEqualSigned => + val secondParamCompiledUnoptimized = simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false) + secondParamCompiledUnoptimized match { + case List(AssemblyLine(CMP, Immediate, NumericConstant(0, _), true)) => + if (OpcodeClasses.ChangesAAlways(firstParamCompiled.last.opcode)) { + Nil + } else { + secondParamCompiledUnoptimized + } + case _ => secondParamCompiledUnoptimized + } + case _ => + simpleOperation(CMP, ctx, rhs, IndexChoice.PreferY, preserveA = true, commutative = false) + } } val (effectiveComparisonType, label) = branches match { case NoBranching => return Nil diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 917d3a13..e1bb286c 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -270,6 +270,16 @@ class Environment(val parent: Option[Environment], val prefix: String) { } yield hc.asl(8) + lc case FunctionCallExpression(name, params) => name match { + case ">>'" => + constantOperation(MathOperator.DecimalShr, params) + case "<<'" => + constantOperation(MathOperator.DecimalShl, params) + case ">>" => + constantOperation(MathOperator.Shr, params) + case "<<" => + constantOperation(MathOperator.Shl, params) + case "*'" => + constantOperation(MathOperator.DecimalTimes, params) case "*" => constantOperation(MathOperator.Times, params) case "&&" | "&" => @@ -282,7 +292,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { None } } - } + }.map(_.quickSimplify) private def constantOperation(op: MathOperator.Value, params: List[Expression]) = { params.map(eval).reduceLeft[Option[Constant]] { (oc, om) => @@ -515,10 +525,10 @@ class Environment(val parent: Option[Environment], val prefix: String) { } addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position) addThing(ConstantThing(stmt.name, a, p), stmt.position) - addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".hi", a.hiByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".lo", a.loByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.lo", a.loByte.quickSimplify, b), stmt.position) if (length < 256) { addThing(ConstantThing(stmt.name + ".length", lengthConst, b), stmt.position) } @@ -549,10 +559,10 @@ class Environment(val parent: Option[Environment], val prefix: String) { } addThing(RelativeVariable(stmt.name + ".first", a, b, zeropage = false), stmt.position) addThing(ConstantThing(stmt.name, a, p), stmt.position) - addThing(ConstantThing(stmt.name + ".hi", a.hiByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".lo", a.loByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte, b), stmt.position) - addThing(ConstantThing(stmt.name + ".array.lo", a.loByte, b), stmt.position) + addThing(ConstantThing(stmt.name + ".hi", a.hiByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".lo", a.loByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.hi", a.hiByte.quickSimplify, b), stmt.position) + addThing(ConstantThing(stmt.name + ".array.lo", a.loByte.quickSimplify, b), stmt.position) if (length < 256) { addThing(ConstantThing(stmt.name + ".length", NumericConstant(length, 1), b), stmt.position) } @@ -588,9 +598,9 @@ class Environment(val parent: Option[Environment], val prefix: String) { val constantValue: Constant = stmt.initialValue.flatMap(eval).getOrElse(Constant.error(s"`$name` has a non-constant value", position)) if (constantValue.requiredSize > typ.size) ErrorReporting.error(s"`$name` is has an invalid value: not in the range of `$typ`", position) addThing(ConstantThing(prefix + name, constantValue, typ), stmt.position) - if (typ.size == 2) { - addThing(ConstantThing(prefix + name + ".hi", constantValue + 1, b), stmt.position) - addThing(ConstantThing(prefix + name + ".lo", constantValue, b), stmt.position) + if (typ.size >= 2) { + addThing(ConstantThing(prefix + name + ".hi", constantValue.hiByte, b), stmt.position) + addThing(ConstantThing(prefix + name + ".lo", constantValue.loByte, b), stmt.position) } } else { if (stmt.stack && stmt.global) ErrorReporting.error(s"`$name` is static or global and cannot be on stack", position)