1
0
mirror of https://github.com/KarolS/millfork.git synced 2025-01-01 06:29:53 +00:00

Const-pure functions

This commit is contained in:
Karol Stasiak 2020-03-19 23:53:16 +01:00
parent 5acf92d4e8
commit d478f3504f
8 changed files with 328 additions and 45 deletions

View File

@ -46,6 +46,16 @@ Examples:
Unlike hardware handlers with `interrupt`, you can treat functions with `kernal_interrupt` like normal functions.
On non-6502-based targets, functions marked as `kernal_interrupt` don't differ from normal functions.
* `const` the function is pure and can be used in constant expressions. `const` functions are not allowed to:
* use constants that have been declared after them
* have local variables
* call non-const functions
* contain any other statements other than return statements and conditional statements
* `<return_type>` is a valid return type, see [Types](./types.md)
* `<params>` is a comma-separated list of parameters, in form `type name`. Allowed types are the same as for local variables.

View File

@ -0,0 +1,132 @@
package millfork.env
import millfork.node.{Expression, FunctionCallExpression, GeneratedConstantExpression, IfStatement, IndexedExpression, LiteralExpression, ReturnStatement, Statement, SumExpression, VariableExpression}
/**
* @author Karol Stasiak
*/
object ConstPureFunctions {
def checkConstPure(env: Environment, function: NormalFunction): Unit = {
if (!function.isConstPure) return
val params = function.params match {
case NormalParamSignature(ps) => ps.map(p => p.name.stripPrefix(function.name + "$")).toSet
}
checkConstPure(env, function.code, params)
}
private def checkConstPure(env: Environment, s: List[Statement], params: Set[String]): Unit = {
s match {
case List(ReturnStatement(Some(expr))) => checkConstPure(env, expr, params)
case List(IfStatement(c, t, e)) =>
checkConstPure(env, c, params)
checkConstPure(env, t, params)
checkConstPure(env, e, params)
case List(IfStatement(c, t, e), bad) =>
checkConstPure(env, c, params)
checkConstPure(env, t, params)
checkConstPure(env, e, params)
bad match {
case ReturnStatement(None) =>
case _ =>
env.log.error(s"Statement ${bad} not allowed in const-pure functions", bad.position)
}
case ReturnStatement(Some(_)) :: bad :: xs =>
bad match {
case ReturnStatement(None) =>
case _ =>
env.log.error(s"Statement ${bad} not allowed in const-pure functions", bad.position)
}
checkConstPure(env, xs, params)
case IfStatement(c, t, Nil) :: e =>
checkConstPure(env, c, params)
checkConstPure(env, t, params)
checkConstPure(env, e, params)
case (bad@ReturnStatement(None)) :: xs =>
env.log.error("Returning without value not allowed in const-pure functions",
bad.position.orElse(xs.headOption.flatMap(_.position)))
checkConstPure(env, xs, params)
case bad :: xs =>
env.log.error(s"Statement $bad not allowed in const-pure functions", bad.position)
checkConstPure(env, xs, params)
case _ =>
}
}
private def checkConstPure(env: Environment, expr: Expression, params: Set[String]): Unit = {
expr match {
case VariableExpression(vname) =>
if (params(vname)) return
if (env.eval(expr).isDefined) return
env.log.error(s"Refering to `$vname` not allowed in const-pure functions", expr.position)
case LiteralExpression(_, _) =>
case GeneratedConstantExpression(_, _) =>
case SumExpression(expressions, _) =>
for((_, e) <- expressions) checkConstPure(env, e, params)
case FunctionCallExpression(functionName, expressions) =>
functionName match {
case "/" | "%%" | "*" | "*'" | "<<" | ">>" | "<<'" | ">>'" | ">>>>" =>
case ">" | ">=" | "<=" | "<" | "!=" | "==" =>
case "|" | "||" | "&" | "&&" | "^" =>
case f if Environment.constOnlyBuiltinFunction(f) =>
case f if Environment.predefinedFunctions(f) =>
case _ =>
env.maybeGet[Thing](functionName) match {
case Some(n: NormalFunction) if n.isConstPure =>
case Some(_: Type) =>
case Some(_: ConstOnlyCallable) =>
case Some(th) =>
env.log.error(s"Calling `${th.name}` not allowed in const-pure functions", expr.position)
case None =>
if (functionName.exists(c => Character.isAlphabetic(c.toInt))) {
env.log.error(s"Calling undefined thing `$functionName` not allowed in const-pure functions", expr.position)
} else {
env.log.error(s"Operator `$functionName` not allowed in const-pure functions", expr.position)
}
}
}
for (e <- expressions) checkConstPure(env, e, params)
case IndexedExpression(vname, index) =>
env.getPointy(vname) match {
case p: ConstantPointy if p.isArray && p.readOnly =>
case _ =>
env.log.error(s"Calling `${vname}` not allowed in const-pure functions", expr.position)
}
checkConstPure(env, index, params)
case _ =>
env.log.error(s"Expression not allowed in const-pure functions", expr.position)
}
}
def eval(env: Environment, function: NormalFunction, args: List[Constant]): Option[Constant] = {
val fitArgs = args.zip(function.params.types).map { case (arg, typ) => arg.fitInto(typ) }
val params = function.params match {
case NormalParamSignature(ps) => ps.zip(fitArgs).map { case (p, arg) => p.name.stripPrefix(function.name + "$") -> arg }.toMap
}
eval(env, function.code, params).map(_.fitInto(function.returnType))
}
@scala.annotation.tailrec
private def eval(env: Environment, code: List[Statement], params: Map[String, Constant]): Option[Constant] = {
code match {
case List(ReturnStatement(Some(expr))) =>
env.eval(expr, params)
case IfStatement(cond, t, Nil) :: xs =>
env.eval(cond, params) match {
case Some(c) if c.isProvablyZero => eval(env, xs, params)
case Some(NumericConstant(value, _)) => eval(env, if (value != 0) t else xs, params)
case _ => None
}
case List(IfStatement(cond, t, e)) =>
env.eval(cond, params) match {
case Some(c) if c.isProvablyZero => eval(env, e, params)
case Some(NumericConstant(value, _)) => eval(env, if (value != 0) t else e, params)
case _ => None
}
case _ => None
}
}
}

View File

@ -121,8 +121,20 @@ sealed trait Constant {
def fitInto(typ: Type): Constant = {
// TODO:
typ.size match {
case 1 => loByte
case 2 => subword(0)
case 1 =>
loByte.quickSimplify match {
case NumericConstant(value, 1) =>
if (typ.isSigned) NumericConstant(value.toByte, 1)
else NumericConstant(value & 0xff, 1)
case b => b
}
case 2 =>
subword(0).quickSimplify match {
case NumericConstant(value, _) =>
if (typ.isSigned) NumericConstant(value.toShort, 2)
else NumericConstant(value & 0xffff, 2)
case w => w
}
case _ => this
}
}

View File

@ -666,10 +666,10 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
case IndexedExpression(arrName, index) =>
getPointy(arrName) match {
case ConstantPointy(MemoryAddressConstant(arr:InitializedArray), _, _, _, _, _, _, _) if arr.readOnly && arr.elementType.size == 1 =>
eval(index).flatMap {
evalImpl(index, vv).flatMap {
case NumericConstant(constIndex, _) =>
if (constIndex >= 0 && constIndex < arr.sizeInBytes) {
eval(arr.contents(constIndex.toInt))
evalImpl(arr.contents(constIndex.toInt), vv)
} else None
case _ => None
}
@ -711,21 +711,21 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
case "hi" =>
if (params.size == 1) {
eval(params.head).map(_.hiByte.quickSimplify)
evalImpl(params.head, vv).map(_.hiByte.quickSimplify)
} else {
log.error("Invalid number of parameters for `hi`", e.position)
None
}
case "lo" =>
if (params.size == 1) {
eval(params.head).map(_.loByte.quickSimplify)
evalImpl(params.head, vv).map(_.loByte.quickSimplify)
} else {
log.error("Invalid number of parameters for `lo`", e.position)
None
}
case "sin" =>
if (params.size == 2) {
(eval(params(0)) -> eval(params(1))) match {
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
val value = (scale * math.sin(angle * math.Pi / 128)).round.toInt
Some(Constant(value))
@ -737,7 +737,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
case "cos" =>
if (params.size == 2) {
(eval(params(0)) -> eval(params(1))) match {
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
val value = (scale * math.cos(angle * math.Pi / 128)).round.toInt
Some(Constant(value))
@ -749,7 +749,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
case "tan" =>
if (params.size == 2) {
(eval(params(0)) -> eval(params(1))) match {
(evalImpl(params(0), vv) -> evalImpl(params(1), vv)) match {
case (Some(NumericConstant(angle, _)), Some(NumericConstant(scale, _))) =>
val value = (scale * math.tan(angle * math.Pi / 128)).round.toInt
Some(Constant(value))
@ -760,17 +760,17 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
None
}
case "min" =>
constantOperation(MathOperator.Minimum, fce)
constantOperation(MathOperator.Minimum, fce, vv)
case "max" =>
constantOperation(MathOperator.Maximum, fce)
constantOperation(MathOperator.Maximum, fce, vv)
case "if" =>
if (params.size == 3) {
eval(params(0)).map(_.quickSimplify) match {
case Some(NumericConstant(cond, _)) =>
eval(params(if (cond != 0) 1 else 2))
evalImpl(params(if (cond != 0) 1 else 2), vv)
case Some(c) =>
if (c.isProvablyGreaterOrEqualThan(1)) eval(params(1))
else if (c.isProvablyZero) eval(params(2))
if (c.isProvablyGreaterOrEqualThan(1)) evalImpl(params(1), vv)
else if (c.isProvablyZero) evalImpl(params(2), vv)
else None
case _ => None
}
@ -781,13 +781,13 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
case "nonet" =>
params match {
case List(FunctionCallExpression("<<", ps@List(_, _))) =>
constantOperation(MathOperator.Shl9, ps)
constantOperation(MathOperator.Shl9, ps, vv)
case List(FunctionCallExpression("<<'", ps@List(_, _))) =>
constantOperation(MathOperator.DecimalShl9, ps)
constantOperation(MathOperator.DecimalShl9, ps, vv)
case List(SumExpression(ps@List((false,_),(false,_)), false)) =>
constantOperation(MathOperator.Plus9, ps.map(_._2))
constantOperation(MathOperator.Plus9, ps.map(_._2), vv)
case List(SumExpression(ps@List((false,_),(false,_)), true)) =>
constantOperation(MathOperator.DecimalPlus9, ps.map(_._2))
constantOperation(MathOperator.DecimalPlus9, ps.map(_._2), vv)
case List(_) =>
None
case _ =>
@ -795,41 +795,59 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
None
}
case ">>'" =>
constantOperation(MathOperator.DecimalShr, params)
constantOperation(MathOperator.DecimalShr, params, vv)
case "<<'" =>
constantOperation(MathOperator.DecimalShl, params)
constantOperation(MathOperator.DecimalShl, params, vv)
case ">>" =>
constantOperation(MathOperator.Shr, params)
constantOperation(MathOperator.Shr, params, vv)
case "<<" =>
constantOperation(MathOperator.Shl, params)
constantOperation(MathOperator.Shl, params, vv)
case ">>>>" =>
constantOperation(MathOperator.Shr9, params)
constantOperation(MathOperator.Shr9, params, vv)
case "*'" =>
constantOperation(MathOperator.DecimalTimes, params)
constantOperation(MathOperator.DecimalTimes, params, vv)
case "*" =>
constantOperation(MathOperator.Times, params)
constantOperation(MathOperator.Times, params, vv)
case "/" =>
constantOperation(MathOperator.Divide, params)
constantOperation(MathOperator.Divide, params, vv)
case "%%" =>
constantOperation(MathOperator.Modulo, params)
constantOperation(MathOperator.Modulo, params, vv)
case "&&" | "&" =>
constantOperation(MathOperator.And, params)
constantOperation(MathOperator.And, params, vv)
case "^" =>
constantOperation(MathOperator.Exor, params)
constantOperation(MathOperator.Exor, params, vv)
case "||" | "|" =>
constantOperation(MathOperator.Or, params)
constantOperation(MathOperator.Or, params, vv)
case ">" => evalComparisons(params, vv, _ > _)
case "<" => evalComparisons(params, vv, _ < _)
case ">=" => evalComparisons(params, vv, _ >= _)
case "<=" => evalComparisons(params, vv, _ <= _)
case "==" => evalComparisons(params, vv, _ == _)
case "!=" =>
sequence(params.map(p => evalImpl(p, vv))) match {
case Some(List(NumericConstant(n1, _), NumericConstant(n2, _))) =>
Some(if (n1 != n2) Constant.One else Constant.Zero)
case _ => None
}
case _ =>
maybeGet[Type](name) match {
maybeGet[Thing](name) match {
case Some(t: StructType) =>
if (params.size == t.fields.size) {
sequence(params.map(eval)).map(fields => StructureConstant(t, fields.zip(t.fields).map{
sequence(params.map(p => evalImpl(p, vv))).map(fields => StructureConstant(t, fields.zip(t.fields).map{
case (fieldConst, fieldDesc) =>
fieldConst.fitInto(get[Type](fieldDesc.typeName))
}))
} else None
case Some(n: NormalFunction) if n.isConstPure =>
if (params.size == n.params.length) {
sequence(params.map(p => evalImpl(p, vv))) match {
case Some(args) => ConstPureFunctions.eval(this, n, args)
case _ => None
}
} else None
case Some(_: UnionType) =>
None
case Some(t) =>
case Some(t: Type) =>
if (params.size == 1) {
eval(params.head).map{ c =>
c.fitInto(t)
@ -860,17 +878,30 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
}
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression): Option[Constant] = {
private def evalComparisons(params: List[Expression], vv: Option[Map[String, Constant]], cond: (Long, Long) => Boolean): Option[Constant] = {
if (params.size < 2) return None
val numbers = sequence(params.map{ e =>
evalImpl(e, vv) match {
case Some(NumericConstant(n, _)) => Some(n)
case _ => None
}
})
numbers.map { ns =>
if (ns.init.zip(ns.tail).forall(cond.tupled)) Constant.One else Constant.Zero
}
}
private def constantOperation(op: MathOperator.Value, fce: FunctionCallExpression, vv: Option[Map[String, Constant]]): Option[Constant] = {
val params = fce.expressions
if (params.isEmpty) {
log.error(s"Invalid number of parameters for `${fce.functionName}`", fce.position)
None
}
constantOperation(op, fce.expressions)
constantOperation(op, fce.expressions, vv)
}
private def constantOperation(op: MathOperator.Value, params: List[Expression]): Option[Constant] = {
params.map(eval).reduceLeft[Option[Constant]] { (oc, om) =>
private def constantOperation(op: MathOperator.Value, params: List[Expression], vv: Option[Map[String, Constant]]): Option[Constant] = {
params.map(p => evalImpl(p, vv)).reduceLeft[Option[Constant]] { (oc, om) =>
for {
c <- oc
m <- om
@ -1252,11 +1283,11 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}
val paramForAutomaticReturn: List[Option[Expression]] = if (stmt.isMacro || stmt.assembly) {
Nil
} else if (statements.isEmpty) {
} else if (executableStatements.isEmpty) {
List(None)
} else {
statements.last match {
case _: ReturnStatement => Nil
executableStatements.last match {
case s if s.isValidFunctionEnd => Nil
case WhileStatement(VariableExpression(tr), _, _, _) =>
if (resultType.size > 0 && env.getBooleanConstant(tr).contains(true)) {
List(Some(LiteralExpression(0, 1))) // TODO: what if the loop is breakable?
@ -1299,12 +1330,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
interrupt = stmt.interrupt,
kernalInterrupt = stmt.kernalInterrupt,
reentrant = stmt.reentrant,
isConstPure = stmt.constPure,
position = stmt.position,
declaredBank = stmt.bank,
alignment = stmt.alignment.getOrElse(if (name == "main") NoAlignment else defaultFunctionAlignment(options, hot = true)) // TODO: decide actual hotness in a smarter way
)
addThing(mangled, stmt.position)
registerAddressConstant(mangled, stmt.position, options, None)
if (mangled.isConstPure) {
ConstPureFunctions.checkConstPure(env, mangled)
}
}
}
}
@ -1415,7 +1450,8 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
assembly = true,
interrupt = false,
kernalInterrupt = false,
reentrant = false
reentrant = false,
constPure = function.isConstPure
), options)
get[FunctionInMemory](function.name + ".trampoline")
}

View File

@ -375,6 +375,8 @@ sealed trait MangledFunction extends CallableThing {
def interrupt: Boolean
def isConstPure: Boolean
def canBePointedTo: Boolean
def requiresTrampoline(compilationOptions: CompilationOptions): Boolean = false
@ -387,6 +389,8 @@ case class EmptyFunction(name: String,
override def interrupt = false
override def isConstPure = false
override def canBePointedTo: Boolean = false
}
@ -397,6 +401,8 @@ case class MacroFunction(name: String,
code: List[ExecutableStatement]) extends MangledFunction {
override def interrupt = false
override def isConstPure = false
override def canBePointedTo: Boolean = false
}
@ -424,6 +430,8 @@ case class ExternFunction(name: String,
override def interrupt = false
override def isConstPure = false
override def zeropage: Boolean = false
override def isVolatile: Boolean = false
@ -439,6 +447,7 @@ case class NormalFunction(name: String,
hasElidedReturnVariable: Boolean,
interrupt: Boolean,
kernalInterrupt: Boolean,
isConstPure: Boolean,
reentrant: Boolean,
position: Option[Position],
declaredBank: Option[String],

View File

@ -4,7 +4,7 @@ import millfork.assembly.Elidability
import millfork.assembly.m6809.{MAddrMode, MOpcode}
import millfork.assembly.mos.opt.SourceOfNZ
import millfork.assembly.mos.{AddrMode, Opcode}
import millfork.assembly.z80.{ZOpcode, ZRegisters}
import millfork.assembly.z80.{NoRegisters, OneRegister, ZOpcode, ZRegisters}
import millfork.env.{Constant, ParamPassingConvention, Type, VariableType}
import millfork.output.MemoryAlignment
@ -608,6 +608,7 @@ case class FunctionDeclarationStatement(name: String,
assembly: Boolean,
interrupt: Boolean,
kernalInterrupt: Boolean,
constPure: Boolean,
reentrant: Boolean) extends BankedDeclarationStatement {
override def getAllExpressions: List[Expression] = address.toList ++ statements.getOrElse(Nil).flatMap(_.getAllExpressions)
@ -626,7 +627,9 @@ case class FunctionDeclarationStatement(name: String,
}
}
sealed trait ExecutableStatement extends Statement
sealed trait ExecutableStatement extends Statement {
def isValidFunctionEnd: Boolean = false
}
case class RawBytesStatement(contents: ArrayContents, bigEndian: Boolean) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = contents.getAllExpressions(bigEndian)
@ -646,10 +649,12 @@ case class ExpressionStatement(expression: Expression) extends ExecutableStateme
case class ReturnStatement(value: Option[Expression]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = value.toList
override def isValidFunctionEnd: Boolean = true
}
case class GotoStatement(target: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(target)
override def isValidFunctionEnd: Boolean = true
}
case class LabelStatement(name: String) extends ExecutableStatement {
@ -694,14 +699,24 @@ case class MosAssemblyStatement(opcode: Opcode.Value, addrMode: AddrMode.Value,
expression.getAllIdentifiers.toSeq.filter(i => !i.contains('.') || i.endsWith(".addr") || i.endsWith(".addr.lo")).map(_.takeWhile(_ != '.'))
case _ => Seq.empty
}
override def isValidFunctionEnd: Boolean = opcode == Opcode.RTS || opcode == Opcode.RTI || opcode == Opcode.JMP || opcode == Opcode.BRA
}
case class Z80AssemblyStatement(opcode: ZOpcode.Value, registers: ZRegisters, offsetExpression: Option[Expression], expression: Expression, elidability: Elidability.Value) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(expression)
override def isValidFunctionEnd: Boolean = registers match {
case NoRegisters | OneRegister(_) =>
opcode == ZOpcode.RETN || opcode == ZOpcode.RETI || opcode == ZOpcode.JP
case _ =>
false
}
}
case class M6809AssemblyStatement(opcode: MOpcode.Value, addrMode: MAddrMode, expression: Expression, elidability: Elidability.Value) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(expression)
override def isValidFunctionEnd: Boolean = opcode == MOpcode.RTS || opcode == MOpcode.JMP || opcode == MOpcode.BRA || opcode == MOpcode.RTI
}
case class IfStatement(condition: Expression, thenBranch: List[ExecutableStatement], elseBranch: List[ExecutableStatement]) extends CompoundStatement {
@ -717,6 +732,8 @@ case class IfStatement(condition: Expression, thenBranch: List[ExecutableStateme
}
override def loopVariable: String = "-none-"
override def isValidFunctionEnd: Boolean = thenBranch.lastOption.fold(false)(_.isValidFunctionEnd) && elseBranch.lastOption.fold(false)(_.isValidFunctionEnd)
}
case class WhileStatement(condition: Expression, body: List[ExecutableStatement], increment: List[ExecutableStatement], labels: Set[String] = Set("", "while")) extends CompoundStatement {

View File

@ -96,7 +96,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
val variableFlags: P[Set[String]] = flags_("const", "static", "volatile", "stack", "register")
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt")
val functionFlags: P[Set[String]] = flags_("asm", "inline", "interrupt", "macro", "noinline", "reentrant", "kernal_interrupt", "const")
val codec: P[((TextCodec, Boolean), Boolean)] = P(position("text codec identifier") ~ identifier.?.map(_.getOrElse(""))).map {
case (_, "" | "default") => (options.platform.defaultCodec -> false) -> options.flag(CompilationFlag.LenientTextEncoding)
@ -592,6 +592,11 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
if (flags("macro") && flags("noinline")) log.error("Noinline and macro exclude each other", Some(p))
if (flags("inline") && flags("macro")) log.error("Macro and inline exclude each other", Some(p))
if (flags("interrupt") && returnType != "void") log.error(s"Interrupt function `$name` has to return void", Some(p))
if (flags("const") && returnType == "void") log.error(s"Const-pure function `$name` cannot return void", Some(p))
if (flags("const") && flags("interrupt")) log.error(s"Const-pure function `$name` cannot be an interrupt", Some(p))
if (flags("const") && flags("kernal_interrupt")) log.error(s"Const-pure function `$name` cannot be a Kernal interrupt", Some(p))
if (flags("const") && flags("macro")) log.error(s"Const-pure function `$name` cannot be a macro", Some(p))
if (flags("const") && flags("asm")) log.error(s"Const-pure function `$name` cannot contain assembly", Some(p))
if (addr.isEmpty && statements.isEmpty) log.error(s"Extern function `$name` must have an address", Some(p))
if (addr.isDefined && alignment.isDefined) log.error(s"Function `$name` has both address and alignment", Some(p))
if (statements.isEmpty && alignment.isDefined) log.error(s"Extern function `$name` cannot have alignment", Some(p))
@ -607,6 +612,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri
flags("asm"),
flags("interrupt"),
flags("kernal_interrupt"),
flags("const") && !flags("asm"),
flags("reentrant")).pos(p))
}

View File

@ -2,7 +2,7 @@ package millfork.test
import millfork.Cpu
import millfork.env.{BasicPlainType, DerivedPlainType, NumericConstant}
import millfork.test.emu.{EmuUnoptimizedCrossPlatformRun, ShouldNotCompile}
import millfork.test.emu.{EmuUnoptimizedCrossPlatformRun, EmuUnoptimizedRun, ShouldNotCompile}
import org.scalatest.{FunSuite, Matchers}
/**
@ -69,4 +69,65 @@ class ConstantSuite extends FunSuite with Matchers {
| }
""".stripMargin)
}
test("Const-pure functions") {
val m = EmuUnoptimizedRun(
"""
| pointer output @$c000
|
| const byte twice(byte x) = x << 1
| const byte abs(sbyte x) {
| if x < 0 { return -x }
| else {return x }
| }
|
| const byte result = twice(30) + abs(-9)
| const array values = [112, twice(21), abs(-4), result]
|
| void main() {
| output = values.addr
| }
""".stripMargin)
val arrayStart = m.readWord(0xc000)
m.readByte(arrayStart + 1) should equal(42)
m.readByte(arrayStart + 2) should equal(4)
m.readByte(arrayStart + 3) should equal(69)
}
test("Const-pure Fibonacci") {
val m = EmuUnoptimizedRun(
"""
| pointer output @$c000
| byte output2 @ $c011
|
| const byte fib(byte x) {
| if x < 2 { return x }
| else {return fib(x-1) + fib(x-2) }
| }
|
| const array values = [for i,0,until,12 [fib(i)]]
|
| void main() {
| output = values.addr
| output2 = fib(11)
| }
|
""".stripMargin)
val arrayStart = m.readWord(0xc000)
m.readByte(arrayStart + 0) should equal(0)
m.readByte(arrayStart + 1) should equal(1)
m.readByte(arrayStart + 2) should equal(1)
m.readByte(arrayStart + 3) should equal(2)
m.readByte(arrayStart + 4) should equal(3)
m.readByte(arrayStart + 5) should equal(5)
m.readByte(arrayStart + 6) should equal(8)
m.readByte(arrayStart + 7) should equal(13)
m.readByte(arrayStart + 8) should equal(21)
m.readByte(arrayStart + 9) should equal(34)
m.readByte(arrayStart + 10) should equal(55)
m.readByte(arrayStart + 11) should equal(89)
m.readByte(0xc011) should equal(89)
}
}