1
0
mirror of https://github.com/KarolS/millfork.git synced 2026-04-20 18:16:35 +00:00

Added constant-only functions min, max, if. Improved handling of constant-only functions.

This commit is contained in:
Karol Stasiak
2020-03-19 19:43:24 +01:00
parent 9cd1e47a37
commit 85030d3147
10 changed files with 150 additions and 48 deletions
@@ -454,6 +454,16 @@ object AbstractExpressionCompiler {
case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("cos", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("tan", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("min" | "max", params) => if (params.isEmpty) b else params.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match {
case 1 => b
case 2 => w
case n if n >= 3 => env.get[Type]("int" + n * 8)
} // TODO: ?
case FunctionCallExpression("if", params) => if (params.length < 3) b else params.tail.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match {
case 1 => b
case 2 => w
case n if n >= 3 => env.get[Type]("int" + n * 8)
} // TODO: ?
case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match {
case 1 => b
case 2 => w
@@ -584,7 +594,11 @@ object AbstractExpressionCompiler {
def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = {
val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x)
env.lookupFunction(f.functionName, paramsWithTypes).getOrElse {
log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
if (Environment.constOnlyBuiltinFunction(f.functionName)){
log.error(s"Cannot use function `${f.functionName}` with non-constant params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
} else {
log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
}
val signature = NormalParamSignature(paramsWithTypes.map { case (t, _) =>
UninitializedMemoryVariable("?", t, VariableAllocationMethod.Auto, None, NoAlignment, isVolatile = false)
})
@@ -561,6 +561,11 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
// Eliminating variables may eliminate carry
FunctionCallExpression("nonet", args.map(arg => optimizeExpr(arg, Map()))).pos(pos)
case FunctionCallExpression(name, args) =>
if (Environment.constOnlyBuiltinFunction(name)) {
if (ctx.env.eval(expr).isEmpty) {
ctx.log.error(s"`$name` should be only used with constant expressions", expr.position)
}
}
ctx.env.maybeGet[Thing](name) match {
case Some(_: MacroFunction) =>
FunctionCallExpression(name, args.map(arg => optimizeExpr(arg, Map()))).pos(pos)
+49 -40
View File
@@ -378,6 +378,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant {
object MathOperator extends Enumeration {
val Plus, Minus, Times, Shl, Shr, Shl9, Shr9, Plus9, DecimalPlus9,
DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShl9, DecimalShr,
Minimum, Maximum,
Divide, Modulo,
And, Or, Exor = Value
}
@@ -469,12 +470,13 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
case MathOperator.And => Constant.Zero
case MathOperator.Divide => Constant.Zero
case MathOperator.Modulo => Constant.Zero
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (NumericConstant(0, _), c) =>
operator match {
case MathOperator.Shl => l
case _ => CompoundConstant(operator, l, r)
case MathOperator.Times => l
case _ => quickSimplify2(l, r)
}
case (c, NumericConstant(0, 1)) =>
operator match {
@@ -492,56 +494,63 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
case MathOperator.Exor => c
case MathOperator.Or => c
case MathOperator.And => Constant.Zero
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (c, NumericConstant(1, 1)) =>
operator match {
case MathOperator.Times => c
case _ => CompoundConstant(operator, l, r)
case MathOperator.Divide => c
case MathOperator.Modulo => Constant.Zero
case _ => quickSimplify2(l, r)
}
case (NumericConstant(1, 1), c) =>
operator match {
case MathOperator.Times => c
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
var size = ls max rs
val bitmask = (1L << (8*size)) - 1
val value = operator match {
case MathOperator.Plus => lv + rv
case MathOperator.Minus => lv - rv
case MathOperator.Times => lv * rv
case MathOperator.Shl => lv << rv
case MathOperator.Shr => lv >> rv
case MathOperator.Shl9 => (lv << rv) & 0x1ff
case MathOperator.Plus9 => (lv + rv) & 0x1ff
case MathOperator.Shr9 => (lv & 0x1ff) >> rv
case MathOperator.Exor => (lv ^ rv) & bitmask
case MathOperator.Or => lv | rv
case MathOperator.And => lv & rv & bitmask
case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv
case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv
case MathOperator.DecimalPlus if ls == 1 && rs == 1 =>
asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff
case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) =>
asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff
case _ => return this
}
operator match {
case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 =>
size = 2
case MathOperator.Times | MathOperator.Shl =>
val mask = (1 << (size * 8)) - 1
if (value != (value & mask)) {
size = ls + rs
}
case _ =>
}
NumericConstant(value, size)
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
}
private def quickSimplify2(l: Constant, r: Constant): Constant = (l, r) match {
case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
var size = ls max rs
val bitmask = (1L << (8*size)) - 1
val value = operator match {
case MathOperator.Minimum => lv min rv
case MathOperator.Maximum => lv max rv
case MathOperator.Plus => lv + rv
case MathOperator.Minus => lv - rv
case MathOperator.Times => lv * rv
case MathOperator.Shl => lv << rv
case MathOperator.Shr => lv >> rv
case MathOperator.Shl9 => (lv << rv) & 0x1ff
case MathOperator.Plus9 => (lv + rv) & 0x1ff
case MathOperator.Shr9 => (lv & 0x1ff) >> rv
case MathOperator.Exor => (lv ^ rv) & bitmask
case MathOperator.Or => lv | rv
case MathOperator.And => lv & rv & bitmask
case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv
case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv
case MathOperator.DecimalPlus if ls == 1 && rs == 1 =>
asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff
case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) =>
asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff
case _ => return this
}
operator match {
case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 =>
size = 2
case MathOperator.Times | MathOperator.Shl =>
val mask = (1 << (size * 8)) - 1
if (value != (value & mask)) {
size = ls + rs
}
case _ =>
}
NumericConstant(value, size)
case _ => CompoundConstant(operator, l, r)
}
import MathOperator._
+42 -4
View File
@@ -529,6 +529,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
val hudson_transfer$ = StructType("hudson_transfer$", List(FieldDesc("word", "a", None), FieldDesc("word", "b", None), FieldDesc("word", "c", None)))
addThing(byte_and_pointer$, None)
addThing(hudson_transfer$, None)
Environment.constOnlyBuiltinFunction.foreach(n => addThing(ConstOnlyCallable(n), None))
builtinsAdded = true
}
@@ -699,7 +700,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
lc <- evalImpl(l, vv)
hc <- evalImpl(h, vv)
} yield hc.asl(8) + lc
case FunctionCallExpression(name, params) =>
case fce@FunctionCallExpression(name, params) =>
name match {
case "sizeof" =>
if (params.size == 1) {
@@ -758,6 +759,25 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
log.error("Invalid number of parameters for `tan`", e.position)
None
}
case "min" =>
constantOperation(MathOperator.Minimum, fce)
case "max" =>
constantOperation(MathOperator.Maximum, fce)
case "if" =>
if (params.size == 3) {
eval(params(0)).map(_.quickSimplify) match {
case Some(NumericConstant(cond, _)) =>
eval(params(if (cond != 0) 1 else 2))
case Some(c) =>
if (c.isProvablyGreaterOrEqualThan(1)) eval(params(1))
else if (c.isProvablyZero) eval(params(2))
else None
case _ => None
}
} else {
log.error("Invalid number of parameters for `if`", e.position)
None
}
case "nonet" =>
params match {
case List(FunctionCallExpression("<<", ps@List(_, _))) =>
@@ -840,7 +860,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
}
private def constantOperation(op: MathOperator.Value, params: List[Expression]) = {
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression): Option[Constant] = {
val params = fce.expressions
if (params.isEmpty) {
log.error(s"Invalid number of parameters for `${fce.functionName}`", fce.position)
None
}
constantOperation(op, fce.expressions)
}
private def constantOperation(op: MathOperator.Value, params: List[Expression]): Option[Constant] = {
params.map(eval).reduceLeft[Option[Constant]] { (oc, om) =>
for {
c <- oc
@@ -1066,6 +1095,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
val w = get[Type]("word")
val p = get[Type]("pointer")
val name = stmt.name
if (Environment.constOnlyBuiltinFunction(name)) {
log.error(s"Cannot redefine a built-in function `$name`", stmt.position)
}
val resultType = get[Type](stmt.resultType)
if (stmt.name == "main") {
if (stmt.resultType != "void" && options.flag(CompilationFlag.UselessCodeWarning)) {
@@ -2006,7 +2038,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = {
if (things.contains(name)) {
val function = get[MangledFunction](name)
val thing = get[Thing](name)
if (!thing.isInstanceOf[MangledFunction]) {
return None
}
val function = thing.asInstanceOf[MangledFunction]
if (function.params.length != actualParams.length) {
log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position))
}
@@ -2326,7 +2362,9 @@ object Environment {
// built-in special-cased functions; can be considered keywords by some:
val predefinedFunctions: Set[String] = Set("not", "hi", "lo", "nonet", "sizeof")
// built-in special-cased functions, not keywords, but assumed to work almost as such:
val specialFunctions: Set[String] = Set("sin", "cos", "tan", "call")
val specialFunctions: Set[String] = Set("call")
// functions that exist only in constants:
val constOnlyBuiltinFunction: Set[String] = Set("sin", "cos", "tan", "min", "max")
// keywords:
val neverIdentifiers: Set[String] = Set(
"array", "const", "alias", "import", "static", "register", "stack", "volatile", "asm", "extern", "kernal_interrupt", "interrupt", "reentrant", "segment",
+2
View File
@@ -18,6 +18,8 @@ sealed trait VariableLikeThing extends Thing
sealed trait IndexableThing extends Thing
case class ConstOnlyCallable(val name: String) extends CallableThing
sealed trait Type extends CallableThing {
def size: Int
@@ -158,6 +158,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
case MathOperator.Shl9 => (l << r) & 0x1ff
case MathOperator.Shr => l >>> r
case MathOperator.Shr9 => ((l & 0x1ff) >>> r) & 0xff
case MathOperator.Minimum => l min r
case MathOperator.Maximum => l max r
case MathOperator.DecimalPlus => asDecimal(l, r, _ + _)
case MathOperator.DecimalPlus9 => asDecimal(l, r, _ + _) & 0x1ff
case MathOperator.DecimalMinus => asDecimal(l, r, _ - _)
@@ -212,7 +212,9 @@ class PreprocessorParser(options: CompilationOptions) {
case ("lo", List(p)) => {m:M => Some(p(m).getOrElse(0L) & 0xff)}
case ("hi", List(p)) => {m:M => Some(p(m).getOrElse(0L).>>(8).&(0xff))}
case ("if", List(i, t, e)) => {m:M => if (i(m).getOrElse(0L) != 0) t(m) else e(m)}
case ("defined" | "lo" | "hi" | "not" | "if", ps) =>
case ("min", ps@(_::_)) => {m:M => ps.map(_(m)).min}
case ("max", ps@(_::_)) => {m:M => ps.map(_(m)).max}
case ("defined" | "lo" | "hi" | "not" | "if" | "min" | "max", ps) =>
log.error(s"Invalid number of parameters to $name: ${ps.length}")
alwaysNone
case _ =>
@@ -43,4 +43,30 @@ class ConstantSuite extends FunSuite with Matchers {
|}
|""".stripMargin)
}
test("Special const functions should work") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)(
"""
| const array values = [111, if (0,1,2), if(1,2,3), min(2,3,4), max(2,3,4)
| pointer output @$c000
| void main() {
|
| }
""".stripMargin){m =>
val arrayStart = m.readWord(0xc000)
m.readByte(arrayStart + 1) should equal(2)
m.readByte(arrayStart + 2) should equal(2)
m.readByte(arrayStart + 3) should equal(2)
m.readByte(arrayStart + 4) should equal(4)
}
}
test("Do not compile const functions with variables") {
ShouldNotCompile(
"""
| byte output
| void main() {
| min(output,0)
| }
""".stripMargin)
}
}