mirror of
https://github.com/KarolS/millfork.git
synced 2025-03-24 10:33:53 +00:00
Optimize constant comparisons
This commit is contained in:
parent
438d8dbe6e
commit
bf5639761b
@ -196,6 +196,44 @@ object AbstractExpressionCompiler {
|
||||
def getExpressionType(env: Environment, log: Logger, expr: Expression): Type = {
|
||||
val b = env.get[Type]("byte")
|
||||
val bool = env.get[Type]("bool$")
|
||||
val boolTrue = env.get[Type]("true$")
|
||||
val boolFalse= env.get[Type]("false$")
|
||||
def toType(x: Boolean): Type = if (x) boolTrue else boolFalse
|
||||
|
||||
def toAllNumericConstants(exprs: List[Expression]): Option[List[Long]] = {
|
||||
for {
|
||||
maybeConstants <- Some(exprs.map(env.eval))
|
||||
if maybeConstants.forall(_.isDefined)
|
||||
constants = maybeConstants.map(_.get)
|
||||
if constants.forall(_.isInstanceOf[NumericConstant])
|
||||
numericConstants = constants.map(_.asInstanceOf[NumericConstant])
|
||||
maxSize = numericConstants.map(_.requiredSize).fold(1)(_ max _)
|
||||
mask = (1L << (8 * maxSize)) - 1
|
||||
signMask = ~mask
|
||||
signTest = if (signMask == 0) Long.MaxValue else signMask >> 1
|
||||
types = exprs.map(e => getExpressionType(env, log, e))
|
||||
if types.forall(_.size.<(8))
|
||||
signednesses = types.flatMap {
|
||||
case d: DerivedPlainType => Some(d.isSigned)
|
||||
case _ => None
|
||||
}.toSet
|
||||
if signednesses.size < 2
|
||||
preserveSigns = signednesses.contains(true)
|
||||
values = numericConstants.map(n => if (preserveSigns && n.value.&(signTest) != 0) n.value.&(mask).|(signMask) else n.value.&(mask))
|
||||
// _ = println(s"$constants $types $signednesses $preserveSigns $mask $signMask $signTest $values")
|
||||
} yield values
|
||||
}
|
||||
def toAllBooleanConstants(exprs: List[Expression]): Option[List[Boolean]] = {
|
||||
for {
|
||||
types <- Some(exprs.map(e => getExpressionType(env, log, e)))
|
||||
if types.forall(_.isInstanceOf[ConstantBooleanType])
|
||||
bools = types.map(_.asInstanceOf[ConstantBooleanType].value)
|
||||
if bools.nonEmpty
|
||||
} yield bools
|
||||
}
|
||||
|
||||
def monotonous(values: List[Long], pred: (Long, Long) => Boolean): Boolean = values.init.zip(values.tail).forall(pred.tupled)
|
||||
|
||||
val v = env.get[Type]("void")
|
||||
val w = env.get[Type]("word")
|
||||
expr match {
|
||||
@ -225,7 +263,11 @@ object AbstractExpressionCompiler {
|
||||
case _ => log.error("Adding values bigger than words", expr.position); w
|
||||
}
|
||||
case FunctionCallExpression("nonet", params) => w
|
||||
case FunctionCallExpression("not", params) => bool
|
||||
case FunctionCallExpression("not", params) =>
|
||||
toAllBooleanConstants(params) match {
|
||||
case Some(List(x)) => toType(!x)
|
||||
case _ => bool
|
||||
}
|
||||
case FunctionCallExpression("hi", params) => b
|
||||
case FunctionCallExpression("lo", params) => b
|
||||
case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match {
|
||||
@ -246,15 +288,27 @@ object AbstractExpressionCompiler {
|
||||
case FunctionCallExpression("<<'", params) => b
|
||||
case FunctionCallExpression(">>'", params) => b
|
||||
case FunctionCallExpression(">>>>", params) => b
|
||||
case FunctionCallExpression("&&", params) => bool
|
||||
case FunctionCallExpression("||", params) => bool
|
||||
case FunctionCallExpression("^^", params) => bool
|
||||
case FunctionCallExpression("==", params) => bool
|
||||
case FunctionCallExpression("!=", params) => bool
|
||||
case FunctionCallExpression("<", params) => bool
|
||||
case FunctionCallExpression(">", params) => bool
|
||||
case FunctionCallExpression("<=", params) => bool
|
||||
case FunctionCallExpression(">=", params) => bool
|
||||
case FunctionCallExpression("&&", params) =>
|
||||
toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ && _)))
|
||||
case FunctionCallExpression("||", params) =>
|
||||
toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ || _)))
|
||||
case FunctionCallExpression("^^", params) =>
|
||||
toAllBooleanConstants(params).fold(bool)(list => toType(list.reduce(_ != _)))
|
||||
case FunctionCallExpression("==", params) =>
|
||||
toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ == _)))
|
||||
case FunctionCallExpression("!=", params) =>
|
||||
toAllNumericConstants(params) match {
|
||||
case Some(List(x, y)) => toType(x != y)
|
||||
case _ => bool
|
||||
}
|
||||
case FunctionCallExpression("<", params) =>
|
||||
toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ < _)))
|
||||
case FunctionCallExpression(">", params) =>
|
||||
toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ > _)))
|
||||
case FunctionCallExpression("<=", params) =>
|
||||
toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ <= _)))
|
||||
case FunctionCallExpression(">=", params) =>
|
||||
toAllNumericConstants(params).fold(bool)(list => toType(monotonous(list, _ >= _)))
|
||||
case FunctionCallExpression("+=", params) => v
|
||||
case FunctionCallExpression("-=", params) => v
|
||||
case FunctionCallExpression("*=", params) => v
|
||||
|
@ -62,4 +62,31 @@ class BooleanSuite extends FunSuite with Matchers {
|
||||
| }
|
||||
""".stripMargin)(_.readByte(0xc000) should equal(11))
|
||||
}
|
||||
|
||||
test("Constant conditions big suite") {
|
||||
val code ="""
|
||||
| byte output @$c000
|
||||
| const pointer outside = $bfff
|
||||
| void main () {
|
||||
| output = 1
|
||||
| if 1 == 1 { pass() }
|
||||
| if 1 == 2 { fail() }
|
||||
| if 1 != 1 { fail() }
|
||||
| if 1 != 2 { pass() }
|
||||
| if 1 < 2 { pass() }
|
||||
| if 1 < 1 { fail() }
|
||||
| if sbyte(1) < sbyte(255) { fail() }
|
||||
| if sbyte(1) > sbyte(255) { pass() }
|
||||
| if sbyte(1) > 0 { pass() }
|
||||
| if sbyte(-1) > 0 { fail() }
|
||||
| if 00001 < 00002 { pass() }
|
||||
| if 00001 > 00002 { fail() }
|
||||
| if 00001 != 00002 { pass() }
|
||||
| if 00001 == 00002 { fail() }
|
||||
| }
|
||||
| inline void pass() { output += 1 }
|
||||
| noinline void fail() { outside[0] = 0 }
|
||||
""".stripMargin
|
||||
EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)(code)(_.readByte(0xc000) should equal(code.sliding(4).count(_ == "pass")))
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user