mirror of
https://github.com/KarolS/millfork.git
synced 2025-01-01 06:29:53 +00:00
Const-pure functions
This commit is contained in:
parent
5acf92d4e8
commit
d478f3504f
@ -46,6 +46,16 @@ Examples:
|
||||
Unlike hardware handlers with `interrupt`, you can treat functions with `kernal_interrupt` like normal functions.
|
||||
On non-6502-based targets, functions marked as `kernal_interrupt` don't differ from normal functions.
|
||||
|
||||
* `const` – the function is pure and can be used in constant expressions. `const` functions are not allowed to:
|
||||
|
||||
* use constants that have been declared after them
|
||||
|
||||
* have local variables
|
||||
|
||||
* call non-const functions
|
||||
|
||||
* contain any other statements other than return statements and conditional statements
|
||||
|
||||
* `<return_type>` is a valid return type, see [Types](./types.md)
|
||||
|
||||
* `<params>` is a comma-separated list of parameters, in form `type name`. Allowed types are the same as for local variables.
|
||||
|
132
src/main/scala/millfork/env/ConstPureFunctions.scala
vendored
Normal file
132
src/main/scala/millfork/env/ConstPureFunctions.scala
vendored
Normal file
@ -0,0 +1,132 @@
|
||||
package millfork.env
|
||||
|
||||
import millfork.node.{Expression, FunctionCallExpression, GeneratedConstantExpression, IfStatement, IndexedExpression, LiteralExpression, ReturnStatement, Statement, SumExpression, VariableExpression}
|
||||
|
||||
/**
|
||||
* @author Karol Stasiak
|
||||
*/
|
||||
object ConstPureFunctions {
|
||||
|
||||
def checkConstPure(env: Environment, function: NormalFunction): Unit = {
|
||||
if (!function.isConstPure) return
|
||||
val params = function.params match {
|
||||
case NormalParamSignature(ps) => ps.map(p => p.name.stripPrefix(function.name + "$")).toSet
|
||||
}
|
||||
checkConstPure(env, function.code, params)
|
||||
}
|
||||
|
||||
private def checkConstPure(env: Environment, s: List[Statement], params: Set[String]): Unit = {
|
||||
s match {
|
||||
case List(ReturnStatement(Some(expr))) => checkConstPure(env, expr, params)
|
||||
case List(IfStatement(c, t, e)) =>
|
||||
checkConstPure(env, c, params)
|
||||
checkConstPure(env, t, params)
|
||||
checkConstPure(env, e, params)
|
||||
case List(IfStatement(c, t, e), bad) =>
|
||||
checkConstPure(env, c, params)
|
||||
checkConstPure(env, t, params)
|
||||
checkConstPure(env, e, params)
|
||||
bad match {
|
||||
case ReturnStatement(None) =>
|
||||
case _ =>
|
||||
env.log.error(s"Statement ${bad} not allowed in const-pure functions", bad.position)
|
||||
}
|
||||
case ReturnStatement(Some(_)) :: bad :: xs =>
|
||||
bad match {
|
||||
case ReturnStatement(None) =>
|
||||
case _ =>
|
||||
env.log.error(s"Statement ${bad} not allowed in const-pure functions", bad.position)
|
||||
}
|
||||
checkConstPure(env, xs, params)
|
||||
case IfStatement(c, t, Nil) :: e =>
|
||||
checkConstPure(env, c, params)
|
||||
checkConstPure(env, t, params)
|
||||
checkConstPure(env, e, params)
|
||||
case (bad@ReturnStatement(None)) :: xs =>
|
||||
env.log.error("Returning without value not allowed in const-pure functions",
|
||||
bad.position.orElse(xs.headOption.flatMap(_.position)))
|
||||
checkConstPure(env, xs, params)
|
||||
case bad :: xs =>
|
||||
env.log.error(s"Statement $bad not allowed in const-pure functions", bad.position)
|
||||
checkConstPure(env, xs, params)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
private def checkConstPure(env: Environment, expr: Expression, params: Set[String]): Unit = {
|
||||
expr match {
|
||||
case VariableExpression(vname) =>
|
||||
if (params(vname)) return
|
||||
if (env.eval(expr).isDefined) return
|
||||
env.log.error(s"Refering to `$vname` not allowed in const-pure functions", expr.position)
|
||||
case LiteralExpression(_, _) =>
|
||||
case GeneratedConstantExpression(_, _) =>
|
||||
case SumExpression(expressions, _) =>
|
||||
for((_, e) <- expressions) checkConstPure(env, e, params)
|
||||
case FunctionCallExpression(functionName, expressions) =>
|
||||
functionName match {
|
||||
case "/" | "%%" | "*" | "*'" | "<<" | ">>" | "<<'" | ">>'" | ">>>>" =>
|
||||
case ">" | ">=" | "<=" | "<" | "!=" | "==" =>
|
||||
case "|" | "||" | "&" | "&&" | "^" =>
|
||||
case f if Environment.constOnlyBuiltinFunction(f) =>
|
||||
case f if Environment.predefinedFunctions(f) =>
|
||||
case _ =>
|
||||
env.maybeGet[Thing](functionName) match {
|
||||
case Some(n: NormalFunction) if n.isConstPure =>
|
||||
case Some(_: Type) =>
|
||||
case Some(_: ConstOnlyCallable) =>
|
||||
case Some(th) =>
|
||||
env.log.error(s"Calling `${th.name}` not allowed in const-pure functions", expr.position)
|
||||
case None =>
|
||||
if (functionName.exists(c => Character.isAlphabetic(c.toInt))) {
|
||||
env.log.error(s"Calling undefined thing `$functionName` not allowed in const-pure functions", expr.position)
|
||||
} else {
|
||||
env.log.error(s"Operator `$functionName` not allowed in const-pure functions", expr.position)
|
||||
}
|
||||
}
|
||||
}
|
||||
for (e <- expressions) checkConstPure(env, e, params)
|
||||
case IndexedExpression(vname, index) =>
|
||||
env.getPointy(vname) match {
|
||||
case p: ConstantPointy if p.isArray && p.readOnly =>
|
||||
case _ =>
|
||||
env.log.error(s"Calling `${vname}` not allowed in const-pure functions", expr.position)
|
||||
}
|
||||
checkConstPure(env, index, params)
|
||||
case _ =>
|
||||
env.log.error(s"Expression not allowed in const-pure functions", expr.position)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def eval(env: Environment, function: NormalFunction, args: List[Constant]): Option[Constant] = {
|
||||
val fitArgs = args.zip(function.params.types).map { case (arg, typ) => arg.fitInto(typ) }
|
||||
val params = function.params match {
|
||||
case NormalParamSignature(ps) => ps.zip(fitArgs).map { case (p, arg) => p.name.stripPrefix(function.name + "$") -> arg }.toMap
|
||||
}
|
||||
eval(env, function.code, params).map(_.fitInto(function.returnType))
|
||||
}
|
||||
|
||||
@scala.annotation.tailrec
|
||||
private def eval(env: Environment, code: List[Statement], params: Map[String, Constant]): Option[Constant] = {
|
||||
code match {
|
||||
case List(ReturnStatement(Some(expr))) =>
|
||||
env.eval(expr, params)
|
||||
case IfStatement(cond, t, Nil) :: xs =>
|
||||
env.eval(cond, params) match {
|
||||
case Some(c) if c.isProvablyZero => eval(env, xs, params)
|
||||
case Some(NumericConstant(value, _)) => eval(env, if (value != 0) t else xs, params)
|
||||
case _ => None
|
||||
}
|
||||
case List(IfStatement(cond, t, e)) =>
|
||||
env.eval(cond, params) match {
|
||||
case Some(c) if c.isProvablyZero => eval(env, e, params)
|
||||
case Some(NumericConstant(value, _)) => eval(env, if (value != 0) t else e, params)
|
||||
case _ => None
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
}
|
16
src/main/scala/millfork/env/Constant.scala
vendored
16
src/main/scala/millfork/env/Constant.scala
vendored
@ -121,8 +121,20 @@ sealed trait Constant {
|
||||
def fitInto(typ: Type): Constant = {
|
||||
// TODO:
|
||||
typ.size match {
|
||||
case 1 => loByte
|
||||
case 2 => subword(0)
|
||||
case 1 =>
|
||||
loByte.quickSimplify match {
|
||||
case NumericConstant(value, 1) =>
|
||||
if (typ.isSigned) NumericConstant(value.toByte, 1)
|
||||
else NumericConstant(value & 0xff, 1)
|
||||
case b => b
|
||||
}
|
||||
case 2 =>
|
||||
subword(0).quickSimplify match {
|
||||
case NumericConstant(value, _) =>
|
||||
if (typ.isSigned) NumericConstant(value.toShort, 2)
|
||||
else NumericConstant(value & 0xffff, 2)
|
||||
case w => w
|
||||
}
|
||||
case _ => this
|
||||
}
|
||||
}
|
||||
|
114
src/main/scala/millfork/env/Environment.scala
vendored
114
src/main/scala/millfork/env/Environment.scala
vendored
@ -666,10 +666,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
case IndexedExpression(arrName, index) =>
|
||||
getPointy(arrName) match {
|
||||
case ConstantPointy(MemoryAddressConstant(arr:InitializedArray), _, _, _, _, _, _, _) if arr.readOnly && arr.elementType.size == 1 =>
|
||||
eval(index).flatMap {
|
||||
evalImpl(index, vv).flatMap {
|
||||
case NumericConstant(constIndex, _) =>
|
||||
if (constIndex >= 0 && constIndex < arr.sizeInBytes) {
|
||||
eval(arr.contents(constIndex.toInt))
|
||||
evalImpl(arr.contents(constIndex.toInt), vv)
|
||||
} else None
|
||||
case _ => None
|
||||
}
|
||||
@ -711,21 +711,21 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
case "hi" =>
|
||||
if (params.size == 1) {
|
||||
eval(params.head).map(_.hiByte.quickSimplify)
|
||||
evalImpl(params.head, vv).map(_.hiByte.quickSimplify)
|
||||
} else {
|
||||
log.error("Invalid number of parameters for `hi`", e.position)
|
||||
None
|
||||
}
|
||||
case "lo" =>
|
||||
if (params.size == 1) {
|
||||
eval(params.head).map(_.loByte.quickSimplify)
|
||||
evalImpl(params.head, vv).map(_.loByte.quickSimplify)
|
||||
} else {
|
||||
log.error("Invalid number of parameters for `lo`", e.position)
|
||||
None
|
||||
}
|
||||
case "sin" =>
|
||||
if (params.size == 2) {
|
||||
(eval(params(0)) -> eval(params(1))) match {
|
||||
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
|
||||
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
|
||||
val value = (scale * math.sin(angle * math.Pi / 128)).round.toInt
|
||||
Some(Constant(value))
|
||||
@ -737,7 +737,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
case "cos" =>
|
||||
if (params.size == 2) {
|
||||
(eval(params(0)) -> eval(params(1))) match {
|
||||
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
|
||||
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
|
||||
val value = (scale * math.cos(angle * math.Pi / 128)).round.toInt
|
||||
Some(Constant(value))
|
||||
@ -749,7 +749,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
case "tan" =>
|
||||
if (params.size == 2) {
|
||||
(eval(params(0)) -> eval(params(1))) match {
|
||||
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
|
||||
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
|
||||
val value = (scale * math.tan(angle * math.Pi / 128)).round.toInt
|
||||
Some(Constant(value))
|
||||
@ -760,17 +760,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
None
|
||||
}
|
||||
case "min" =>
|
||||
constantOperation(MathOperator.Minimum, fce)
|
||||
constantOperation(MathOperator.Minimum, fce, vv)
|
||||
case "max" =>
|
||||
constantOperation(MathOperator.Maximum, fce)
|
||||
constantOperation(MathOperator.Maximum, fce, vv)
|
||||
case "if" =>
|
||||
if (params.size == 3) {
|
||||
eval(params(0)).map(_.quickSimplify) match {
|
||||
case Some(NumericConstant(cond, _)) =>
|
||||
eval(params(if (cond != 0) 1 else 2))
|
||||
evalImpl(params(if (cond != 0) 1 else 2), vv)
|
||||
case Some(c) =>
|
||||
if (c.isProvablyGreaterOrEqualThan(1)) eval(params(1))
|
||||
else if (c.isProvablyZero) eval(params(2))
|
||||
if (c.isProvablyGreaterOrEqualThan(1)) evalImpl(params(1), vv)
|
||||
else if (c.isProvablyZero) evalImpl(params(2), vv)
|
||||
else None
|
||||
case _ => None
|
||||
}
|
||||
@ -781,13 +781,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
case "nonet" =>
|
||||
params match {
|
||||
case List(FunctionCallExpression("<<", ps@List(_, _))) =>
|
||||
constantOperation(MathOperator.Shl9, ps)
|
||||
constantOperation(MathOperator.Shl9, ps, vv)
|
||||
case List(FunctionCallExpression("<<'", ps@List(_, _))) =>
|
||||
constantOperation(MathOperator.DecimalShl9, ps)
|
||||
constantOperation(MathOperator.DecimalShl9, ps, vv)
|
||||
case List(SumExpression(ps@List((false,_),(false,_)), false)) =>
|
||||
constantOperation(MathOperator.Plus9, ps.map(_._2))
|
||||
constantOperation(MathOperator.Plus9, ps.map(_._2), vv)
|
||||
case List(SumExpression(ps@List((false,_),(false,_)), true)) =>
|
||||
constantOperation(MathOperator.DecimalPlus9, ps.map(_._2))
|
||||
constantOperation(MathOperator.DecimalPlus9, ps.map(_._2), vv)
|
||||
case List(_) =>
|
||||
None
|
||||
case _ =>
|
||||
@ -795,41 +795,59 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
None
|
||||
}
|
||||
case ">>'" =>
|
||||
constantOperation(MathOperator.DecimalShr, params)
|
||||
constantOperation(MathOperator.DecimalShr, params, vv)
|
||||
case "<<'" =>
|
||||
constantOperation(MathOperator.DecimalShl, params)
|
||||
constantOperation(MathOperator.DecimalShl, params, vv)
|
||||
case ">>" =>
|
||||
constantOperation(MathOperator.Shr, params)
|
||||
constantOperation(MathOperator.Shr, params, vv)
|
||||
case "<<" =>
|
||||
constantOperation(MathOperator.Shl, params)
|
||||
constantOperation(MathOperator.Shl, params, vv)
|
||||
case ">>>>" =>
|
||||
constantOperation(MathOperator.Shr9, params)
|
||||
constantOperation(MathOperator.Shr9, params, vv)
|
||||
case "*'" =>
|
||||
constantOperation(MathOperator.DecimalTimes, params)
|
||||
constantOperation(MathOperator.DecimalTimes, params, vv)
|
||||
case "*" =>
|
||||
constantOperation(MathOperator.Times, params)
|
||||
constantOperation(MathOperator.Times, params, vv)
|
||||
case "/" =>
|
||||
constantOperation(MathOperator.Divide, params)
|
||||
constantOperation(MathOperator.Divide, params, vv)
|
||||
case "%%" =>
|
||||
constantOperation(MathOperator.Modulo, params)
|
||||
constantOperation(MathOperator.Modulo, params, vv)
|
||||
case "&&" | "&" =>
|
||||
constantOperation(MathOperator.And, params)
|
||||
constantOperation(MathOperator.And, params, vv)
|
||||
case "^" =>
|
||||
constantOperation(MathOperator.Exor, params)
|
||||
constantOperation(MathOperator.Exor, params, vv)
|
||||
case "||" | "|" =>
|
||||
constantOperation(MathOperator.Or, params)
|
||||
constantOperation(MathOperator.Or, params, vv)
|
||||
case ">" => evalComparisons(params, vv, _ > _)
|
||||
case "<" => evalComparisons(params, vv, _ < _)
|
||||
case ">=" => evalComparisons(params, vv, _ >= _)
|
||||
case "<=" => evalComparisons(params, vv, _ <= _)
|
||||
case "==" => evalComparisons(params, vv, _ == _)
|
||||
case "!=" =>
|
||||
sequence(params.map(p => evalImpl(p, vv))) match {
|
||||
case Some(List(NumericConstant(n1, _), NumericConstant(n2, _))) =>
|
||||
Some(if (n1 != n2) Constant.One else Constant.Zero)
|
||||
case _ => None
|
||||
}
|
||||
case _ =>
|
||||
maybeGet[Type](name) match {
|
||||
maybeGet[Thing](name) match {
|
||||
case Some(t: StructType) =>
|
||||
if (params.size == t.fields.size) {
|
||||
sequence(params.map(eval)).map(fields => StructureConstant(t, fields.zip(t.fields).map{
|
||||
sequence(params.map(p => evalImpl(p, vv))).map(fields => StructureConstant(t, fields.zip(t.fields).map{
|
||||
case (fieldConst, fieldDesc) =>
|
||||
fieldConst.fitInto(get[Type](fieldDesc.typeName))
|
||||
}))
|
||||
} else None
|
||||
case Some(n: NormalFunction) if n.isConstPure =>
|
||||
if (params.size == n.params.length) {
|
||||
sequence(params.map(p => evalImpl(p, vv))) match {
|
||||
case Some(args) => ConstPureFunctions.eval(this, n, args)
|
||||
case _ => None
|
||||
}
|
||||
} else None
|
||||
case Some(_: UnionType) =>
|
||||
None
|
||||
case Some(t) =>
|
||||
case Some(t: Type) =>
|
||||
if (params.size == 1) {
|
||||
eval(params.head).map{ c =>
|
||||
c.fitInto(t)
|
||||
@ -860,17 +878,30 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
}
|
||||
|
||||
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression): Option[Constant] = {
|
||||
private def evalComparisons(params: List[Expression], vv: Option[Map[String, Constant]], cond: (Long, Long) => Boolean): Option[Constant] = {
|
||||
if (params.size < 2) return None
|
||||
val numbers = sequence(params.map{ e =>
|
||||
evalImpl(e, vv) match {
|
||||
case Some(NumericConstant(n, _)) => Some(n)
|
||||
case _ => None
|
||||
}
|
||||
})
|
||||
numbers.map { ns =>
|
||||
if (ns.init.zip(ns.tail).forall(cond.tupled)) Constant.One else Constant.Zero
|
||||
}
|
||||
}
|
||||
|
||||
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression, vv: Option[Map[String, Constant]]): Option[Constant] = {
|
||||
val params = fce.expressions
|
||||
if (params.isEmpty) {
|
||||
log.error(s"Invalid number of parameters for `${fce.functionName}`", fce.position)
|
||||
None
|
||||
}
|
||||
constantOperation(op, fce.expressions)
|
||||
constantOperation(op, fce.expressions, vv)
|
||||
}
|
||||
|
||||
private def constantOperation(op: MathOperator.Value, params: List[Expression]): Option[Constant] = {
|
||||
params.map(eval).reduceLeft[Option[Constant]] { (oc, om) =>
|
||||
private def constantOperation(op: MathOperator.Value, params: List[Expression], vv: Option[Map[String, Constant]]): Option[Constant] = {
|
||||
params.map(p => evalImpl(p, vv)).reduceLeft[Option[Constant]] { (oc, om) =>
|
||||
for {
|
||||
c <- oc
|
||||
m <- om
|
||||
@ -1252,11 +1283,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
}
|
||||
val paramForAutomaticReturn: List[Option[Expression]] = if (stmt.isMacro || stmt.assembly) {
|
||||
Nil
|
||||
} else if (statements.isEmpty) {
|
||||
} else if (executableStatements.isEmpty) {
|
||||
List(None)
|
||||
} else {
|
||||
statements.last match {
|
||||
case _: ReturnStatement => Nil
|
||||
executableStatements.last match {
|
||||
case s if s.isValidFunctionEnd => Nil
|
||||
case WhileStatement(VariableExpression(tr), _, _, _) =>
|
||||
if (resultType.size > 0 && env.getBooleanConstant(tr).contains(true)) {
|
||||
List(Some(LiteralExpression(0, 1))) // TODO: what if the loop is breakable?
|
||||
@ -1299,12 +1330,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
interrupt = stmt.interrupt,
|
||||
kernalInterrupt = stmt.kernalInterrupt,
|
||||
reentrant = stmt.reentrant,
|
||||
isConstPure = stmt.constPure,
|
||||
position = stmt.position,
|
||||
declaredBank = stmt.bank,
|
||||
alignment = stmt.alignment.getOrElse(if (name == "main") NoAlignment else defaultFunctionAlignment(options, hot = true)) // TODO: decide actual hotness in a smarter way
|
||||
)
|
||||
addThing(mangled, stmt.position)
|
||||
registerAddressConstant(mangled, stmt.position, options, None)
|
||||
if (mangled.isConstPure) {
|
||||
ConstPureFunctions.checkConstPure(env, mangled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1415,7 +1450,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
|
||||
assembly = true,
|
||||
interrupt = false,
|
||||
kernalInterrupt = false,
|
||||
reentrant = false
|
||||
reentrant = false,
|
||||
constPure = function.isConstPure
|
||||
), options)
|
||||
get[FunctionInMemory](function.name + ".trampoline")
|
||||
}
|
||||
|
9
src/main/scala/millfork/env/Thing.scala
vendored
9
src/main/scala/millfork/env/Thing.scala
vendored
@ -375,6 +375,8 @@ sealed trait MangledFunction extends CallableThing {
|
||||
|
||||
def interrupt: Boolean
|
||||
|
||||
def isConstPure: Boolean
|
||||
|
||||
def canBePointedTo: Boolean
|
||||
|
||||
def requiresTrampoline(compilationOptions: CompilationOptions): Boolean = false
|
||||
@ -387,6 +389,8 @@ case class EmptyFunction(name: String,
|
||||
|
||||
override def interrupt = false
|
||||
|
||||
override def isConstPure = false
|
||||
|
||||
override def canBePointedTo: Boolean = false
|
||||
}
|
||||
|
||||
@ -397,6 +401,8 @@ case class MacroFunction(name: String,
|
||||
code: List[ExecutableStatement]) extends MangledFunction {
|
||||
override def interrupt = false
|
||||
|
||||
override def isConstPure = false
|
||||
|
||||
override def canBePointedTo: Boolean = false
|
||||
}
|
||||
|
||||
@ -424,6 +430,8 @@ case class ExternFunction(name: String,
|
||||
|
||||
override def interrupt = false
|
||||
|
||||
override def isConstPure = false
|
||||
|
||||
override def zeropage: Boolean = false
|
||||
|
||||
override def isVolatile: Boolean = false
|
||||
@ -439,6 +447,7 @@ case class NormalFunction(name: String,
|
||||
hasElidedReturnVariable: Boolean,
|
||||
interrupt: Boolean,
|
||||
kernalInterrupt: Boolean,
|
||||
isConstPure: Boolean,
|
||||
reentrant: Boolean,
|
||||
position: Option[Position],
|
||||
declaredBank: Option[String],
|
||||
|
@ -4,7 +4,7 @@ import millfork.assembly.Elidability
|
||||
import millfork.assembly.m6809.{MAddrMode, MOpcode}
|
||||
import millfork.assembly.mos.opt.SourceOfNZ
|
||||
import millfork.assembly.mos.{AddrMode, Opcode}
|
||||
import millfork.assembly.z80.{ZOpcode, ZRegisters}
|
||||
import millfork.assembly.z80.{NoRegisters, OneRegister, ZOpcode, ZRegisters}
|
||||
import millfork.env.{Constant, ParamPassingConvention, Type, VariableType}
|
||||
import millfork.output.MemoryAlignment
|
||||
|
||||
@ -608,6 +608,7 @@ case class FunctionDeclarationStatement(name: String,
|
||||
assembly: Boolean,
|
||||
interrupt: Boolean,
|
||||
kernalInterrupt: Boolean,
|
||||
constPure: Boolean,
|
||||
reentrant: Boolean) extends BankedDeclarationStatement {
|
||||
override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions)
|
||||
|
||||
@ -626,7 +627,9 @@ case class FunctionDeclarationStatement(name: String,
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait ExecutableStatement extends Statement
|
||||
sealed trait ExecutableStatement extends Statement {
|
||||
def isValidFunctionEnd: Boolean = false
|
||||
}
|
||||
|
||||
case class RawBytesStatement(contents: ArrayContents, bigEndian: Boolean) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = contents.getAllExpressions(bigEndian)
|
||||
@ -646,10 +649,12 @@ case class ExpressionStatement(expression: Expression) extends ExecutableStateme
|
||||
|
||||
case class ReturnStatement(value: Option[Expression]) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = value.toList
|
||||
override def isValidFunctionEnd: Boolean = true
|
||||
}
|
||||
|
||||
case class GotoStatement(target: Expression) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = List(target)
|
||||
override def isValidFunctionEnd: Boolean = true
|
||||
}
|
||||
|
||||
case class LabelStatement(name: String) extends ExecutableStatement {
|
||||
@ -694,14 +699,24 @@ case class MosAssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value,
|
||||
expression.getAllIdentifiers.toSeq.filter(i => !i.contains('.') || i.endsWith(".addr") || i.endsWith(".addr.lo")).map(_.takeWhile(_ != '.'))
|
||||
case _ => Seq.empty
|
||||
}
|
||||
override def isValidFunctionEnd: Boolean = opcode == Opcode.RTS || opcode == Opcode.RTI || opcode == Opcode.JMP || opcode == Opcode.BRA
|
||||
}
|
||||
|
||||
case class Z80AssemblyStatement(opcode: ZOpcode.Value, registers: ZRegisters, offsetExpression: Option[Expression], expression: Expression, elidability: Elidability.Value) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = List(expression)
|
||||
|
||||
override def isValidFunctionEnd: Boolean = registers match {
|
||||
case NoRegisters | OneRegister(_) =>
|
||||
opcode == ZOpcode.RETN || opcode == ZOpcode.RETI || opcode == ZOpcode.JP
|
||||
case _ =>
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
case class M6809AssemblyStatement(opcode: MOpcode.Value, addrMode: MAddrMode, expression: Expression, elidability: Elidability.Value) extends ExecutableStatement {
|
||||
override def getAllExpressions: List[Expression] = List(expression)
|
||||
|
||||
override def isValidFunctionEnd: Boolean = opcode == MOpcode.RTS || opcode == MOpcode.JMP || opcode == MOpcode.BRA || opcode == MOpcode.RTI
|
||||
}
|
||||
|
||||
case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends CompoundStatement {
|
||||
@ -717,6 +732,8 @@ case class IfStatement(condition: Expression, thenBranch: List[ExecutableStateme
|
||||
}
|
||||
|
||||
override def loopVariable: String = "-none-"
|
||||
|
||||
override def isValidFunctionEnd: Boolean = thenBranch.lastOption.fold(false)(_.isValidFunctionEnd) && elseBranch.lastOption.fold(false)(_.isValidFunctionEnd)
|
||||
}
|
||||
|
||||
case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends CompoundStatement {
|
||||
|
@ -96,7 +96,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
|
||||
|
||||
val variableFlags: P[Set[String]] = flags_("const", "static", "volatile", "stack", "register")
|
||||
|
||||
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt")
|
||||
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt", "const")
|
||||
|
||||
val codec: P[((TextCodec, Boolean), Boolean)] = P(position("text codec identifier") ~ identifier.?.map(_.getOrElse(""))).map {
|
||||
case (_, "" | "default") => (options.platform.defaultCodec -> false) -> options.flag(CompilationFlag.LenientTextEncoding)
|
||||
@ -592,6 +592,11 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
|
||||
if (flags("macro") && flags("noinline")) log.error("Noinline and macro exclude each other", Some(p))
|
||||
if (flags("inline") && flags("macro")) log.error("Macro and inline exclude each other", Some(p))
|
||||
if (flags("interrupt") && returnType != "void") log.error(s"Interrupt function `$name` has to return void", Some(p))
|
||||
if (flags("const") && returnType == "void") log.error(s"Const-pure function `$name` cannot return void", Some(p))
|
||||
if (flags("const") && flags("interrupt")) log.error(s"Const-pure function `$name` cannot be an interrupt", Some(p))
|
||||
if (flags("const") && flags("kernal_interrupt")) log.error(s"Const-pure function `$name` cannot be a Kernal interrupt", Some(p))
|
||||
if (flags("const") && flags("macro")) log.error(s"Const-pure function `$name` cannot be a macro", Some(p))
|
||||
if (flags("const") && flags("asm")) log.error(s"Const-pure function `$name` cannot contain assembly", Some(p))
|
||||
if (addr.isEmpty && statements.isEmpty) log.error(s"Extern function `$name` must have an address", Some(p))
|
||||
if (addr.isDefined && alignment.isDefined) log.error(s"Function `$name` has both address and alignment", Some(p))
|
||||
if (statements.isEmpty && alignment.isDefined) log.error(s"Extern function `$name` cannot have alignment", Some(p))
|
||||
@ -607,6 +612,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
|
||||
flags("asm"),
|
||||
flags("interrupt"),
|
||||
flags("kernal_interrupt"),
|
||||
flags("const") && !flags("asm"),
|
||||
flags("reentrant")).pos(p))
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,7 @@ package millfork.test
|
||||
|
||||
import millfork.Cpu
|
||||
import millfork.env.{BasicPlainType, DerivedPlainType, NumericConstant}
|
||||
import millfork.test.emu.{EmuUnoptimizedCrossPlatformRun, ShouldNotCompile}
|
||||
import millfork.test.emu.{EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun, ShouldNotCompile}
|
||||
import org.scalatest.{FunSuite, Matchers}
|
||||
|
||||
/**
|
||||
@ -69,4 +69,65 @@ class ConstantSuite extends FunSuite with Matchers {
|
||||
| }
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
test("Const-pure functions") {
|
||||
val m = EmuUnoptimizedRun(
|
||||
"""
|
||||
| pointer output @$c000
|
||||
|
|
||||
| const byte twice(byte x) = x << 1
|
||||
| const byte abs(sbyte x) {
|
||||
| if x < 0 { return -x }
|
||||
| else {return x }
|
||||
| }
|
||||
|
|
||||
| const byte result = twice(30) + abs(-9)
|
||||
| const array values = [112, twice(21), abs(-4), result]
|
||||
|
|
||||
| void main() {
|
||||
| output = values.addr
|
||||
| }
|
||||
""".stripMargin)
|
||||
val arrayStart = m.readWord(0xc000)
|
||||
m.readByte(arrayStart + 1) should equal(42)
|
||||
m.readByte(arrayStart + 2) should equal(4)
|
||||
m.readByte(arrayStart + 3) should equal(69)
|
||||
|
||||
}
|
||||
|
||||
test("Const-pure Fibonacci") {
|
||||
val m = EmuUnoptimizedRun(
|
||||
"""
|
||||
| pointer output @$c000
|
||||
| byte output2 @ $c011
|
||||
|
|
||||
| const byte fib(byte x) {
|
||||
| if x < 2 { return x }
|
||||
| else {return fib(x-1) + fib(x-2) }
|
||||
| }
|
||||
|
|
||||
| const array values = [for i,0,until,12 [fib(i)]]
|
||||
|
|
||||
| void main() {
|
||||
| output = values.addr
|
||||
| output2 = fib(11)
|
||||
| }
|
||||
|
|
||||
""".stripMargin)
|
||||
val arrayStart = m.readWord(0xc000)
|
||||
m.readByte(arrayStart + 0) should equal(0)
|
||||
m.readByte(arrayStart + 1) should equal(1)
|
||||
m.readByte(arrayStart + 2) should equal(1)
|
||||
m.readByte(arrayStart + 3) should equal(2)
|
||||
m.readByte(arrayStart + 4) should equal(3)
|
||||
m.readByte(arrayStart + 5) should equal(5)
|
||||
m.readByte(arrayStart + 6) should equal(8)
|
||||
m.readByte(arrayStart + 7) should equal(13)
|
||||
m.readByte(arrayStart + 8) should equal(21)
|
||||
m.readByte(arrayStart + 9) should equal(34)
|
||||
m.readByte(arrayStart + 10) should equal(55)
|
||||
m.readByte(arrayStart + 11) should equal(89)
|
||||
m.readByte(0xc011) should equal(89)
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user