1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-01 06:29:53 +00:00

Added constant-only functions min, max, if. Improved handling of constant-only functions.

This commit is contained in:
Karol Stasiak 2020-03-19 19:43:24 +01:00
parent 9cd1e47a37
commit 85030d3147
10 changed files with 150 additions and 48 deletions

View File

@ -199,7 +199,7 @@ Fields of arithmetic, pointer and enum types are declared using normal expressio
Fields of struct types are declared using struct constructors.
Fields of union types cannot be declared.
What might be useful is the fact that the compiler allows for built-in trigonometric functions
What might be useful is the fact that the compiler allows for certain built-in functions
in constant expressions only:
* `sin(x, n)` returns _n_·**sin**(*x*π/128)
@ -208,3 +208,7 @@ in constant expressions only:
* `tan(x, n)` returns _n_·**tan**(*x*π/128)
* `min(x,...)` returns the smallest argument
* `max(x,...)` returns the largest argument

View File

@ -135,7 +135,7 @@ The `if` function returns its second parameter if the first parameter is defined
#infoeval if(0, 400, 500)
TODO
`not`, `lo`, `hi`, `+`, `-`, `*`, `|`, `&`, `^`, `||`, `&&`, `<<`, `>>`,`==`, `!=`, `>`, `>=`, `<`, `<=`
`not`, `lo`, `hi`, `min`, `max` `+`, `-`, `*`, `|`, `&`, `^`, `||`, `&&`, `<<`, `>>`,`==`, `!=`, `>`, `>=`, `<`, `<=`
The following Millfork operators and functions are not available in the preprocessor:
`+'`, `-'`, `*'`, `<<'`, `>>'`, `:`, `>>>>`, `nonet`, all the assignment operators

View File

@ -454,6 +454,16 @@ object AbstractExpressionCompiler {
case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("cos", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("tan", params) => if (params.size < 2) b else getExpressionTypeImpl(env, log, params(1), loosely)
case FunctionCallExpression("min" | "max", params) => if (params.isEmpty) b else params.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match {
case 1 => b
case 2 => w
case n if n >= 3 => env.get[Type]("int" + n * 8)
} // TODO: ?
case FunctionCallExpression("if", params) => if (params.length < 3) b else params.tail.map { e => getExpressionTypeImpl(env, log, e, loosely).size }.max match {
case 1 => b
case 2 => w
case n if n >= 3 => env.get[Type]("int" + n * 8)
} // TODO: ?
case FunctionCallExpression("sizeof", params) => env.evalSizeof(params.head).requiredSize match {
case 1 => b
case 2 => w
@ -584,7 +594,11 @@ object AbstractExpressionCompiler {
def lookupFunction(env: Environment, log: Logger, f: FunctionCallExpression): MangledFunction = {
val paramsWithTypes = f.expressions.map(x => getExpressionType(env, log, x) -> x)
env.lookupFunction(f.functionName, paramsWithTypes).getOrElse {
log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
if (Environment.constOnlyBuiltinFunction(f.functionName)){
log.error(s"Cannot use function `${f.functionName}` with non-constant params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
} else {
log.error(s"Cannot find function `${f.functionName}` with given params `${paramsWithTypes.map(_._1).mkString("(", ",", ")")}`", f.position)
}
val signature = NormalParamSignature(paramsWithTypes.map { case (t, _) =>
UninitializedMemoryVariable("?", t, VariableAllocationMethod.Auto, None, NoAlignment, isVolatile = false)
})

View File

@ -561,6 +561,11 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte
// Eliminating variables may eliminate carry
FunctionCallExpression("nonet", args.map(arg => optimizeExpr(arg, Map()))).pos(pos)
case FunctionCallExpression(name, args) =>
if (Environment.constOnlyBuiltinFunction(name)) {
if (ctx.env.eval(expr).isEmpty) {
ctx.log.error(s"`$name` should be only used with constant expressions", expr.position)
}
}
ctx.env.maybeGet[Thing](name) match {
case Some(_: MacroFunction) =>
FunctionCallExpression(name, args.map(arg => optimizeExpr(arg, Map()))).pos(pos)

View File

@ -378,6 +378,7 @@ case class SubbyteConstant(base: Constant, index: Int) extends Constant {
object MathOperator extends Enumeration {
val Plus, Minus, Times, Shl, Shr, Shl9, Shr9, Plus9, DecimalPlus9,
DecimalPlus, DecimalMinus, DecimalTimes, DecimalShl, DecimalShl9, DecimalShr,
Minimum, Maximum,
Divide, Modulo,
And, Or, Exor = Value
}
@ -469,12 +470,13 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
case MathOperator.And => Constant.Zero
case MathOperator.Divide => Constant.Zero
case MathOperator.Modulo => Constant.Zero
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (NumericConstant(0, _), c) =>
operator match {
case MathOperator.Shl => l
case _ => CompoundConstant(operator, l, r)
case MathOperator.Times => l
case _ => quickSimplify2(l, r)
}
case (c, NumericConstant(0, 1)) =>
operator match {
@ -492,56 +494,63 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
case MathOperator.Exor => c
case MathOperator.Or => c
case MathOperator.And => Constant.Zero
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (c, NumericConstant(1, 1)) =>
operator match {
case MathOperator.Times => c
case _ => CompoundConstant(operator, l, r)
case MathOperator.Divide => c
case MathOperator.Modulo => Constant.Zero
case _ => quickSimplify2(l, r)
}
case (NumericConstant(1, 1), c) =>
operator match {
case MathOperator.Times => c
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
var size = ls max rs
val bitmask = (1L << (8*size)) - 1
val value = operator match {
case MathOperator.Plus => lv + rv
case MathOperator.Minus => lv - rv
case MathOperator.Times => lv * rv
case MathOperator.Shl => lv << rv
case MathOperator.Shr => lv >> rv
case MathOperator.Shl9 => (lv << rv) & 0x1ff
case MathOperator.Plus9 => (lv + rv) & 0x1ff
case MathOperator.Shr9 => (lv & 0x1ff) >> rv
case MathOperator.Exor => (lv ^ rv) & bitmask
case MathOperator.Or => lv | rv
case MathOperator.And => lv & rv & bitmask
case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv
case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv
case MathOperator.DecimalPlus if ls == 1 && rs == 1 =>
asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff
case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) =>
asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff
case _ => return this
}
operator match {
case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 =>
size = 2
case MathOperator.Times | MathOperator.Shl =>
val mask = (1 << (size * 8)) - 1
if (value != (value & mask)) {
size = ls + rs
}
case _ =>
}
NumericConstant(value, size)
case _ => CompoundConstant(operator, l, r)
case _ => quickSimplify2(l, r)
}
}
private def quickSimplify2(l: Constant, r: Constant): Constant = (l, r) match {
case (NumericConstant(lv, ls), NumericConstant(rv, rs)) =>
var size = ls max rs
val bitmask = (1L << (8*size)) - 1
val value = operator match {
case MathOperator.Minimum => lv min rv
case MathOperator.Maximum => lv max rv
case MathOperator.Plus => lv + rv
case MathOperator.Minus => lv - rv
case MathOperator.Times => lv * rv
case MathOperator.Shl => lv << rv
case MathOperator.Shr => lv >> rv
case MathOperator.Shl9 => (lv << rv) & 0x1ff
case MathOperator.Plus9 => (lv + rv) & 0x1ff
case MathOperator.Shr9 => (lv & 0x1ff) >> rv
case MathOperator.Exor => (lv ^ rv) & bitmask
case MathOperator.Or => lv | rv
case MathOperator.And => lv & rv & bitmask
case MathOperator.Divide if lv >= 0 && rv >= 0 => lv / rv
case MathOperator.Modulo if lv >= 0 && rv >= 0 => lv % rv
case MathOperator.DecimalPlus if ls == 1 && rs == 1 =>
asDecimal(lv & 0xff, rv & 0xff, _ + _) & 0xff
case MathOperator.DecimalMinus if ls == 1 && rs == 1 && lv.&(0xff) >= rv.&(0xff) =>
asDecimal(lv & 0xff, rv & 0xff, _ - _) & 0xff
case _ => return this
}
operator match {
case MathOperator.Plus9 | MathOperator.DecimalPlus9 | MathOperator.Shl9 | MathOperator.DecimalShl9 =>
size = 2
case MathOperator.Times | MathOperator.Shl =>
val mask = (1 << (size * 8)) - 1
if (value != (value & mask)) {
size = ls + rs
}
case _ =>
}
NumericConstant(value, size)
case _ => CompoundConstant(operator, l, r)
}
import MathOperator._

View File

@ -529,6 +529,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
val hudson_transfer$ = StructType("hudson_transfer$", List(FieldDesc("word", "a", None), FieldDesc("word", "b", None), FieldDesc("word", "c", None)))
addThing(byte_and_pointer$, None)
addThing(hudson_transfer$, None)
Environment.constOnlyBuiltinFunction.foreach(n => addThing(ConstOnlyCallable(n), None))
builtinsAdded = true
}
@ -699,7 +700,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
lc <- evalImpl(l, vv)
hc <- evalImpl(h, vv)
} yield hc.asl(8) + lc
case FunctionCallExpression(name, params) =>
case fce@FunctionCallExpression(name, params) =>
name match {
case "sizeof" =>
if (params.size == 1) {
@ -758,6 +759,25 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
log.error("Invalid number of parameters for `tan`", e.position)
None
}
case "min" =>
constantOperation(MathOperator.Minimum, fce)
case "max" =>
constantOperation(MathOperator.Maximum, fce)
case "if" =>
if (params.size == 3) {
eval(params(0)).map(_.quickSimplify) match {
case Some(NumericConstant(cond, _)) =>
eval(params(if (cond != 0) 1 else 2))
case Some(c) =>
if (c.isProvablyGreaterOrEqualThan(1)) eval(params(1))
else if (c.isProvablyZero) eval(params(2))
else None
case _ => None
}
} else {
log.error("Invalid number of parameters for `if`", e.position)
None
}
case "nonet" =>
params match {
case List(FunctionCallExpression("<<", ps@List(_, _))) =>
@ -840,7 +860,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
}
private def constantOperation(op: MathOperator.Value, params: List[Expression]) = {
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression): Option[Constant] = {
val params = fce.expressions
if (params.isEmpty) {
log.error(s"Invalid number of parameters for `${fce.functionName}`", fce.position)
None
}
constantOperation(op, fce.expressions)
}
private def constantOperation(op: MathOperator.Value, params: List[Expression]): Option[Constant] = {
params.map(eval).reduceLeft[Option[Constant]] { (oc, om) =>
for {
c <- oc
@ -1066,6 +1095,9 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
val w = get[Type]("word")
val p = get[Type]("pointer")
val name = stmt.name
if (Environment.constOnlyBuiltinFunction(name)) {
log.error(s"Cannot redefine a built-in function `$name`", stmt.position)
}
val resultType = get[Type](stmt.resultType)
if (stmt.name == "main") {
if (stmt.resultType != "void" && options.flag(CompilationFlag.UselessCodeWarning)) {
@ -2006,7 +2038,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
def lookupFunction(name: String, actualParams: List[(Type, Expression)]): Option[MangledFunction] = {
if (things.contains(name)) {
val function = get[MangledFunction](name)
val thing = get[Thing](name)
if (!thing.isInstanceOf[MangledFunction]) {
return None
}
val function = thing.asInstanceOf[MangledFunction]
if (function.params.length != actualParams.length) {
log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position))
}
@ -2326,7 +2362,9 @@ object Environment {
// built-in special-cased functions; can be considered keywords by some:
val predefinedFunctions: Set[String] = Set("not", "hi", "lo", "nonet", "sizeof")
// built-in special-cased functions, not keywords, but assumed to work almost as such:
val specialFunctions: Set[String] = Set("sin", "cos", "tan", "call")
val specialFunctions: Set[String] = Set("call")
// functions that exist only in constants:
val constOnlyBuiltinFunction: Set[String] = Set("sin", "cos", "tan", "min", "max")
// keywords:
val neverIdentifiers: Set[String] = Set(
"array", "const", "alias", "import", "static", "register", "stack", "volatile", "asm", "extern", "kernal_interrupt", "interrupt", "reentrant", "segment",

View File

@ -18,6 +18,8 @@ sealed trait VariableLikeThing extends Thing
sealed trait IndexableThing extends Thing
case class ConstOnlyCallable(val name: String) extends CallableThing
sealed trait Type extends CallableThing {
def size: Int

View File

@ -158,6 +158,8 @@ abstract class AbstractAssembler[T <: AbstractCode](private val program: Program
case MathOperator.Shl9 => (l << r) & 0x1ff
case MathOperator.Shr => l >>> r
case MathOperator.Shr9 => ((l & 0x1ff) >>> r) & 0xff
case MathOperator.Minimum => l min r
case MathOperator.Maximum => l max r
case MathOperator.DecimalPlus => asDecimal(l, r, _ + _)
case MathOperator.DecimalPlus9 => asDecimal(l, r, _ + _) & 0x1ff
case MathOperator.DecimalMinus => asDecimal(l, r, _ - _)

View File

@ -212,7 +212,9 @@ class PreprocessorParser(options: CompilationOptions) {
case ("lo", List(p)) => {m:M => Some(p(m).getOrElse(0L) & 0xff)}
case ("hi", List(p)) => {m:M => Some(p(m).getOrElse(0L).>>(8).&(0xff))}
case ("if", List(i, t, e)) => {m:M => if (i(m).getOrElse(0L) != 0) t(m) else e(m)}
case ("defined" | "lo" | "hi" | "not" | "if", ps) =>
case ("min", ps@(_::_)) => {m:M => ps.map(_(m)).min}
case ("max", ps@(_::_)) => {m:M => ps.map(_(m)).max}
case ("defined" | "lo" | "hi" | "not" | "if" | "min" | "max", ps) =>
log.error(s"Invalid number of parameters to $name: ${ps.length}")
alwaysNone
case _ =>

View File

@ -43,4 +43,30 @@ class ConstantSuite extends FunSuite with Matchers {
|}
|""".stripMargin)
}
test("Special const functions should work") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Intel8086, Cpu.Motorola6809)(
"""
| const array values = [111, if (0,1,2), if(1,2,3), min(2,3,4), max(2,3,4)
| pointer output @$c000
| void main() {
|
| }
""".stripMargin){m =>
val arrayStart = m.readWord(0xc000)
m.readByte(arrayStart + 1) should equal(2)
m.readByte(arrayStart + 2) should equal(2)
m.readByte(arrayStart + 3) should equal(2)
m.readByte(arrayStart + 4) should equal(4)
}
}
test("Do not compile const functions with variables") {
ShouldNotCompile(
"""
| byte output
| void main() {
| min(output,0)
| }
""".stripMargin)
}
}