diff --git a/docs/lang/literals.md b/docs/lang/literals.md index f878cf56..1aff648d 100644 --- a/docs/lang/literals.md +++ b/docs/lang/literals.md @@ -199,7 +199,7 @@ Fields of arithmetic, pointer and enum types are declared using normal expressio Fields of struct types are declared using struct constructors. Fields of union types cannot be declared. -What might be useful is the fact that the compiler allows for built-in trigonometric functions +What might be useful is the fact that the compiler allows for certain built-in functions in constant expressions only: * `sin(x, n)` – returns _n_·**sin**(*x*π/128) @@ -208,3 +208,7 @@ in constant expressions only: * `tan(x, n)` – returns _n_·**tan**(*x*π/128) +* `min(x,...)` – returns the smallest argument + +* `max(x,...)` – returns the largest argument + diff --git a/docs/lang/preprocessor.md b/docs/lang/preprocessor.md index 9c0682df..e305fc65 100644 --- a/docs/lang/preprocessor.md +++ b/docs/lang/preprocessor.md @@ -135,7 +135,7 @@ The `if` function returns its second parameter if the first parameter is defined #infoeval if(0, 400, 500) TODO -`not`, `lo`, `hi`, `+`, `-`, `*`, `|`, `&`, `^`, `||`, `&&`, `<<`, `>>`,`==`, `!=`, `>`, `>=`, `<`, `<=` +`not`, `lo`, `hi`, `min`, `max` `+`, `-`, `*`, `|`, `&`, `^`, `||`, `&&`, `<<`, `>>`,`==`, `!=`, `>`, `>=`, `<`, `<=` The following Millfork operators and functions are not available in the preprocessor: `+'`, `-'`, `*'`, `<<'`, `>>'`, `:`, `>>>>`, `nonet`, all the assignment operators diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 8dedf4ec..5b9ed625 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -454,6 +454,16 @@ object AbstractExpressionCompiler { case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely) case FunctionCallExpression("cos", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely) case FunctionCallExpression("tan", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely) + case FunctionCallExpression("min" | "max", params) => if (params.isEmpty) b else params.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match { + case 1 => b + case 2 => w + case n if n >= 3 => env.get[Type]("int" + n * 8) + } // TODO: ? + case FunctionCallExpression("if", params) => if (params.length < 3) b else params.tail.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match { + case 1 => b + case 2 => w + case n if n >= 3 => env.get[Type]("int" + n * 8) + } // TODO: ? case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match { case 1 => b case 2 => w @@ -584,7 +594,11 @@ object AbstractExpressionCompiler { def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = { val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x) env.lookupFunction(f.functionName, paramsWithTypes).getOrElse { - log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position) + if (Environment.constOnlyBuiltinFunction(f.functionName)){ + log.error(s"Cannot use function `${f.functionName}` with non-constant params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position) + } else { + log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position) + } val signature = NormalParamSignature(paramsWithTypes.map { case (t, _) => UninitializedMemoryVariable("?", t, VariableAllocationMethod.Auto, None, NoAlignment, isVolatile = false) }) diff --git a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala index b0730901..34170361 100644 --- a/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala +++ b/src/main/scala/millfork/compiler/AbstractStatementPreprocessor.scala @@ -561,6 +561,11 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte // Eliminating variables may eliminate carry FunctionCallExpression("nonet", args.map(arg => optimizeExpr(arg, Map()))).pos(pos) case FunctionCallExpression(name, args) => + if (Environment.constOnlyBuiltinFunction(name)) { + if (ctx.env.eval(expr).isEmpty) { + ctx.log.error(s"`$name` should be only used with constant expressions", expr.position) + } + } ctx.env.maybeGet[Thing](name) match { case Some(_: MacroFunction) => FunctionCallExpression(name, args.map(arg => optimizeExpr(arg, Map()))).pos(pos) diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 01b07fec..142f599d 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -378,6 +378,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant { object MathOperator extends Enumeration { val Plus, Minus, Times, Shl, Shr, Shl9, Shr9, Plus9, DecimalPlus9, DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShl9, DecimalShr, + Minimum, Maximum, Divide, Modulo, And, Or, Exor = Value } @@ -469,12 +470,13 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.And => Constant.Zero case MathOperator.Divide => Constant.Zero case MathOperator.Modulo => Constant.Zero - case _ => CompoundConstant(operator, l, r) + case _ => quickSimplify2(l, r) } case (NumericConstant(0, _), c) => operator match { case MathOperator.Shl => l - case _ => CompoundConstant(operator, l, r) + case MathOperator.Times => l + case _ => quickSimplify2(l, r) } case (c, NumericConstant(0, 1)) => operator match { @@ -492,56 +494,63 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co case MathOperator.Exor => c case MathOperator.Or => c case MathOperator.And => Constant.Zero - case _ => CompoundConstant(operator, l, r) + case _ => quickSimplify2(l, r) } case (c, NumericConstant(1, 1)) => operator match { case MathOperator.Times => c - case _ => CompoundConstant(operator, l, r) + case MathOperator.Divide => c + case MathOperator.Modulo => Constant.Zero + case _ => quickSimplify2(l, r) } case (NumericConstant(1, 1), c) => operator match { case MathOperator.Times => c - case _ => CompoundConstant(operator, l, r) + case _ => quickSimplify2(l, r) } - case (NumericConstant(lv, ls), NumericConstant(rv, rs)) => - var size = ls max rs - val bitmask = (1L << (8*size)) - 1 - val value = operator match { - case MathOperator.Plus => lv + rv - case MathOperator.Minus => lv - rv - case MathOperator.Times => lv * rv - case MathOperator.Shl => lv << rv - case MathOperator.Shr => lv >> rv - case MathOperator.Shl9 => (lv << rv) & 0x1ff - case MathOperator.Plus9 => (lv + rv) & 0x1ff - case MathOperator.Shr9 => (lv & 0x1ff) >> rv - case MathOperator.Exor => (lv ^ rv) & bitmask - case MathOperator.Or => lv | rv - case MathOperator.And => lv & rv & bitmask - case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv - case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv - case MathOperator.DecimalPlus if ls == 1 && rs == 1 => - asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff - case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) => - asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff - case _ => return this - } - operator match { - case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 => - size = 2 - case MathOperator.Times | MathOperator.Shl => - val mask = (1 << (size * 8)) - 1 - if (value != (value & mask)) { - size = ls + rs - } - case _ => - } - NumericConstant(value, size) - case _ => CompoundConstant(operator, l, r) + case _ => quickSimplify2(l, r) } } + private def quickSimplify2(l: Constant, r: Constant): Constant = (l, r) match { + case (NumericConstant(lv, ls), NumericConstant(rv, rs)) => + var size = ls max rs + val bitmask = (1L << (8*size)) - 1 + val value = operator match { + case MathOperator.Minimum => lv min rv + case MathOperator.Maximum => lv max rv + case MathOperator.Plus => lv + rv + case MathOperator.Minus => lv - rv + case MathOperator.Times => lv * rv + case MathOperator.Shl => lv << rv + case MathOperator.Shr => lv >> rv + case MathOperator.Shl9 => (lv << rv) & 0x1ff + case MathOperator.Plus9 => (lv + rv) & 0x1ff + case MathOperator.Shr9 => (lv & 0x1ff) >> rv + case MathOperator.Exor => (lv ^ rv) & bitmask + case MathOperator.Or => lv | rv + case MathOperator.And => lv & rv & bitmask + case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv + case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv + case MathOperator.DecimalPlus if ls == 1 && rs == 1 => + asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff + case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) => + asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff + case _ => return this + } + operator match { + case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 => + size = 2 + case MathOperator.Times | MathOperator.Shl => + val mask = (1 << (size * 8)) - 1 + if (value != (value & mask)) { + size = ls + rs + } + case _ => + } + NumericConstant(value, size) + case _ => CompoundConstant(operator, l, r) + } import MathOperator._ diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index ffbf9850..0f9b9b28 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -529,6 +529,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val hudson_transfer$ = StructType("hudson_transfer$", List(FieldDesc("word", "a", None), FieldDesc("word", "b", None), FieldDesc("word", "c", None))) addThing(byte_and_pointer$, None) addThing(hudson_transfer$, None) + Environment.constOnlyBuiltinFunction.foreach(n => addThing(ConstOnlyCallable(n), None)) builtinsAdded = true } @@ -699,7 +700,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa lc <- evalImpl(l, vv) hc <- evalImpl(h, vv) } yield hc.asl(8) + lc - case FunctionCallExpression(name, params) => + case fce@FunctionCallExpression(name, params) => name match { case "sizeof" => if (params.size == 1) { @@ -758,6 +759,25 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa log.error("Invalid number of parameters for `tan`", e.position) None } + case "min" => + constantOperation(MathOperator.Minimum, fce) + case "max" => + constantOperation(MathOperator.Maximum, fce) + case "if" => + if (params.size == 3) { + eval(params(0)).map(_.quickSimplify) match { + case Some(NumericConstant(cond, _)) => + eval(params(if (cond != 0) 1 else 2)) + case Some(c) => + if (c.isProvablyGreaterOrEqualThan(1)) eval(params(1)) + else if (c.isProvablyZero) eval(params(2)) + else None + case _ => None + } + } else { + log.error("Invalid number of parameters for `if`", e.position) + None + } case "nonet" => params match { case List(FunctionCallExpression("<<", ps@List(_, _))) => @@ -840,7 +860,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } - private def constantOperation(op: MathOperator.Value, params: List[Expression]) = { + private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression): Option[Constant] = { + val params = fce.expressions + if (params.isEmpty) { + log.error(s"Invalid number of parameters for `${fce.functionName}`", fce.position) + None + } + constantOperation(op, fce.expressions) + } + + private def constantOperation(op: MathOperator.Value, params: List[Expression]): Option[Constant] = { params.map(eval).reduceLeft[Option[Constant]] { (oc, om) => for { c <- oc @@ -1066,6 +1095,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa val w = get[Type]("word") val p = get[Type]("pointer") val name = stmt.name + if (Environment.constOnlyBuiltinFunction(name)) { + log.error(s"Cannot redefine a built-in function `$name`", stmt.position) + } val resultType = get[Type](stmt.resultType) if (stmt.name == "main") { if (stmt.resultType != "void" && options.flag(CompilationFlag.UselessCodeWarning)) { @@ -2006,7 +2038,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = { if (things.contains(name)) { - val function = get[MangledFunction](name) + val thing = get[Thing](name) + if (!thing.isInstanceOf[MangledFunction]) { + return None + } + val function = thing.asInstanceOf[MangledFunction] if (function.params.length != actualParams.length) { log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position)) } @@ -2326,7 +2362,9 @@ object Environment { // built-in special-cased functions; can be considered keywords by some: val predefinedFunctions: Set[String] = Set("not", "hi", "lo", "nonet", "sizeof") // built-in special-cased functions, not keywords, but assumed to work almost as such: - val specialFunctions: Set[String] = Set("sin", "cos", "tan", "call") + val specialFunctions: Set[String] = Set("call") + // functions that exist only in constants: + val constOnlyBuiltinFunction: Set[String] = Set("sin", "cos", "tan", "min", "max") // keywords: val neverIdentifiers: Set[String] = Set( "array", "const", "alias", "import", "static", "register", "stack", "volatile", "asm", "extern", "kernal_interrupt", "interrupt", "reentrant", "segment", diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index 01936242..1a54bef4 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -18,6 +18,8 @@ sealed trait VariableLikeThing extends Thing sealed trait IndexableThing extends Thing +case class ConstOnlyCallable(val name: String) extends CallableThing + sealed trait Type extends CallableThing { def size: Int diff --git a/src/main/scala/millfork/output/AbstractAssembler.scala b/src/main/scala/millfork/output/AbstractAssembler.scala index 073e44fe..a13c36f8 100644 --- a/src/main/scala/millfork/output/AbstractAssembler.scala +++ b/src/main/scala/millfork/output/AbstractAssembler.scala @@ -158,6 +158,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program case MathOperator.Shl9 => (l << r) & 0x1ff case MathOperator.Shr => l >>> r case MathOperator.Shr9 => ((l & 0x1ff) >>> r) & 0xff + case MathOperator.Minimum => l min r + case MathOperator.Maximum => l max r case MathOperator.DecimalPlus => asDecimal(l, r, _ + _) case MathOperator.DecimalPlus9 => asDecimal(l, r, _ + _) & 0x1ff case MathOperator.DecimalMinus => asDecimal(l, r, _ - _) diff --git a/src/main/scala/millfork/parser/Preprocessor.scala b/src/main/scala/millfork/parser/Preprocessor.scala index 3ff363cf..2472a7d0 100644 --- a/src/main/scala/millfork/parser/Preprocessor.scala +++ b/src/main/scala/millfork/parser/Preprocessor.scala @@ -212,7 +212,9 @@ class PreprocessorParser(options: CompilationOptions) { case ("lo", List(p)) => {m:M => Some(p(m).getOrElse(0L) & 0xff)} case ("hi", List(p)) => {m:M => Some(p(m).getOrElse(0L).>>(8).&(0xff))} case ("if", List(i, t, e)) => {m:M => if (i(m).getOrElse(0L) != 0) t(m) else e(m)} - case ("defined" | "lo" | "hi" | "not" | "if", ps) => + case ("min", ps@(_::_)) => {m:M => ps.map(_(m)).min} + case ("max", ps@(_::_)) => {m:M => ps.map(_(m)).max} + case ("defined" | "lo" | "hi" | "not" | "if" | "min" | "max", ps) => log.error(s"Invalid number of parameters to $name: ${ps.length}") alwaysNone case _ => diff --git a/src/test/scala/millfork/test/ConstantSuite.scala b/src/test/scala/millfork/test/ConstantSuite.scala index efe15d7e..e0fa9b55 100644 --- a/src/test/scala/millfork/test/ConstantSuite.scala +++ b/src/test/scala/millfork/test/ConstantSuite.scala @@ -43,4 +43,30 @@ class ConstantSuite extends FunSuite with Matchers { |} |""".stripMargin) } + test("Special const functions should work") { + EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)( + """ + | const array values = [111, if (0,1,2), if(1,2,3), min(2,3,4), max(2,3,4) + | pointer output @$c000 + | void main() { + | + | } + """.stripMargin){m => + val arrayStart = m.readWord(0xc000) + m.readByte(arrayStart + 1) should equal(2) + m.readByte(arrayStart + 2) should equal(2) + m.readByte(arrayStart + 3) should equal(2) + m.readByte(arrayStart + 4) should equal(4) + } + } + + test("Do not compile const functions with variables") { + ShouldNotCompile( + """ + | byte output + | void main() { + | min(output,0) + | } + """.stripMargin) + } }