1
0
mirror of https://github.com/KarolS/millfork.git synced 2024-06-30 06:29:31 +00:00

Added return dispatch statements.

This commit is contained in:
Karol Stasiak 2018-01-30 17:38:32 +01:00
parent ac51bcaf6c
commit c26d36f974
21 changed files with 485 additions and 30 deletions

13
CHANGELOG.md Normal file
View File

@ -0,0 +1,13 @@
#Change log
##Current version
* Added return dispatch statements.
* Fixed several optimization bugs.
* Other minor improvements.
## 0.1
* Initial numbered version.

View File

@ -12,6 +12,12 @@ even up to hardware damage.
* reading uninitialized variables: will return undefined values * reading uninitialized variables: will return undefined values
* reading variables used by return dispatch statements but not assigned a value: will return undefined values
* returning a value from a function by return dispatch to a function of different return type: will return undefined values
* passing an index out of range for a return dispatch statement
* stack overflow: exhausting the hardware stack due to excess recursion, excess function calls or excess stack-allocated variables * stack overflow: exhausting the hardware stack due to excess recursion, excess function calls or excess stack-allocated variables
* on ROM-based platforms: writing to arrays * on ROM-based platforms: writing to arrays

View File

@ -31,17 +31,28 @@
## Code generation options ## Code generation options
* `-fcmos-ops`, `-fno-cmos-ops` Whether should emit CMOS opcodes. `.ini` equivalent: `emit_cmos`. * `-fcmos-ops`, `-fno-cmos-ops` Whether should emit CMOS opcodes.
`.ini` equivalent: `emit_cmos`. Default: yes if targeting 65C02, no otherwise.
* `-fillegals`, `-fno-illegals` Whether should emit illegal (undocumented) NMOS opcodes. `.ini` equivalent: `emit_illegals`. * `-fillegals`, `-fno-illegals` Whether should emit illegal (undocumented) NMOS opcodes.
`.ini` equivalent: `emit_illegals`. Default: no.
* `-fjmp-fix`, `-fno-jmp-fix` Whether should prevent indirect JMP bug on page boundary. `.ini` equivalent: `prevent_jmp_indirect_bug`. * `-fjmp-fix`, `-fno-jmp-fix` Whether should prevent indirect JMP bug on page boundary.
`.ini` equivalent: `prevent_jmp_indirect_bug`. Default: no if targeting 65C02, yes otherwise.
* `-fdecimal-mode`, `-fno-decimal-mode` Whether decimal mode should be available. `.ini` equivalent: `decimal_mode`. * `-fdecimal-mode`, `-fno-decimal-mode` Whether decimal mode should be available.
`.ini` equivalent: `decimal_mode`. Default: no if targeting Ricoh, yes otherwise.
* `-fvariable-overlap`, `-fno-variable-overlap` Whether variables should overlap if their scopes do not intersect. Default: yes. * `-fvariable-overlap`, `-fno-variable-overlap` Whether variables should overlap if their scopes do not intersect.
Default: yes.
* `-fbounds-checking`, `-fnobounds-checking` Whether should insert bounds checking on array access. Default: no. * `-fbounds-checking`, `-fnobounds-checking` Whether should insert bounds checking on array access.
Default: no.
* `-fcompact-dispatch-params`, `-fnocompact-dispatch-params`
Whether parameter values in return dispatch statements may overlap other objects.
This may cause problems if the parameter table is stored next to a hardware register that has side effects when reading.
`.ini` equivalent: `compact_dispatch_params`. Default: yes.
## Optimization options ## Optimization options

View File

@ -79,6 +79,72 @@ if <expression> {
} }
``` ```
### `return` statement
Syntax:
```
return
```
```
return <expression>
```
### `return[]` statement (return dispatch)
Syntax examples:
```
return [a + b] {
0 @ underflow
255 @ overflow
default @ nothing
}
```
```
return [getF()] {
1 @ function1
2 @ function2
default(5) @ functionDefault
}
```
```
return [i] (param1, param2) {
1,5,8 @ function1(4, 6)
2 @ function2(9)
default(0,20) @ functionDefault
}
```
Return dispatch calculates the value of an index, picks the correct branch,
assigns some global variables and jumps to another function.
The index has to evaluate to a byte. The functions cannot be `inline` and shouldn't have parameters.
Jumping to a function with parameters gives those parameters undefined values.
The functions are not called, so they don't return to the function the return dispatch statement is in, but to its caller.
The return values are passed along. If the dispatching function has a non-`void` return type different that the type
of the function dispatched to, the return value is undefined.
If the `default` branch exists, then it is used for every missing index value between other supported values.
Optional parameters to `default` specify the maximum, or both the minimum and maximum supported index value.
In the above examples: the first example supports values 0255, second 15, and third 020.
If the index has an unsupported value, the behaviour is formally undefined, but in practice the program will simply crash.
Before jumping to the function, the chosen global variables will be assigned parameter values.
Variables have to be global byte-sized. Some simple array indexing expressions are also allowed.
Parameter values have to be constants.
For example, in the third example one of the following will happen:
* if `i` is 1, 5 or 8, then `param1` is assigned 4, `param2` is assigned 6 and then `function1` is called;
* if `i` is 2, then `param1` is assigned 9, `param2` is assigned an undefined value and then `function2` is called;
* if `i` is any other value from 0 to 20, then `param1` and `param2` are assigned undefined values and then `functionDefault` is called;
* if `i` has any other value, then undefined behaviour.
### `while` and `do-while` statements ### `while` and `do-while` statements
Syntax: Syntax:

View File

@ -5,7 +5,7 @@ import millfork.error.ErrorReporting
/** /**
* @author Karol Stasiak * @author Karol Stasiak
*/ */
class CompilationOptions(val platform: Platform, val commandLineFlags: Map[CompilationFlag.Value, Boolean]) { case class CompilationOptions(platform: Platform, commandLineFlags: Map[CompilationFlag.Value, Boolean]) {
import CompilationFlag._ import CompilationFlag._
import Cpu._ import Cpu._
@ -46,11 +46,11 @@ object Cpu extends Enumeration {
import CompilationFlag._ import CompilationFlag._
def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match { def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match {
case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams)
case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams)
case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap) case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams)
case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap) case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams)
case Cmos => Set(EmitCmosOpcodes, VariableOverlap) case Cmos => Set(EmitCmosOpcodes, VariableOverlap, CompactReturnDispatchParams)
} }
def fromString(name: String): Cpu.Value = name match { def fromString(name: String): Cpu.Value = name match {
@ -77,7 +77,7 @@ object CompilationFlag extends Enumeration {
// optimization options: // optimization options:
DetailedFlowAnalysis, DangerousOptimizations, InlineFunctions, DetailedFlowAnalysis, DangerousOptimizations, InlineFunctions,
// memory allocation options // memory allocation options
VariableOverlap, VariableOverlap, CompactReturnDispatchParams,
// runtime check options // runtime check options
CheckIndexOutOfBounds, CheckIndexOutOfBounds,
// warning options // warning options
@ -94,6 +94,7 @@ object CompilationFlag extends Enumeration {
"ro_arrays" -> ReadOnlyArrays, "ro_arrays" -> ReadOnlyArrays,
"ror_warn" -> RorWarning, "ror_warn" -> RorWarning,
"prevent_jmp_indirect_bug" -> PreventJmpIndirectBug, "prevent_jmp_indirect_bug" -> PreventJmpIndirectBug,
"compact_dispatch_params" -> CompactReturnDispatchParams,
) )
} }

View File

@ -63,7 +63,7 @@ object Main {
ErrorReporting.info("No platform selected, defaulting to `c64`") ErrorReporting.info("No platform selected, defaulting to `c64`")
"c64" "c64"
}) })
val options = new CompilationOptions(platform, c.flags) val options = CompilationOptions(platform, c.flags)
ErrorReporting.debug("Effective flags: " + options.flags) ErrorReporting.debug("Effective flags: " + options.flags)
val output = c.outputFileName.getOrElse("a") val output = c.outputFileName.getOrElse("a")
@ -195,6 +195,9 @@ object Main {
boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) => boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) =>
c.changeFlag(CompilationFlag.VariableOverlap, v) c.changeFlag(CompilationFlag.VariableOverlap, v)
}.description("Whether variables should overlap if their scopes do not intersect.") }.description("Whether variables should overlap if their scopes do not intersect.")
boolean("-fcompact-dispatch-params", "-fno-compact-dispatch-params").action { (c, v) =>
c.changeFlag(CompilationFlag.CompactReturnDispatchParams, v)
}.description("Whether parameter values in return dispatch statements may overlap other objects.")
boolean("-fbounds-checking", "-fno-bounds-checking").action { (c, v) => boolean("-fbounds-checking", "-fno-bounds-checking").action { (c, v) =>
c.changeFlag(CompilationFlag.VariableOverlap, v) c.changeFlag(CompilationFlag.VariableOverlap, v)
}.description("Whether should insert bounds checking on array access.") }.description("Whether should insert bounds checking on array access.")

View File

@ -306,7 +306,7 @@ case class AssemblyLine(opcode: Opcode.Value, addrMode: AddrMode.Value, var para
def sizeInBytes: Int = addrMode match { def sizeInBytes: Int = addrMode match {
case Implied => 1 case Implied => 1
case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2 case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2
case AbsoluteX | Absolute | AbsoluteY | Indirect => 3 case AbsoluteIndexedX | AbsoluteX | Absolute | AbsoluteY | Indirect => 3
case DoesNotExist => 0 case DoesNotExist => 0
} }

View File

@ -9,4 +9,7 @@ import millfork.env.{Environment, MangledFunction, NormalFunction}
case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){ case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){
def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i) def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i)
def neverCheckArrayBounds: CompilationContext =
this.copy(options = options.copy(commandLineFlags = options.commandLineFlags + (CompilationFlag.CheckIndexOutOfBounds -> false)))
} }

View File

@ -1509,6 +1509,8 @@ object MfCompiler {
List(AssemblyLine.discardYF()) ++ returnInstructions) List(AssemblyLine.discardYF()) ++ returnInstructions)
} }
} }
case s : ReturnDispatchStatement =>
LinearChunk(ReturnDispatch.compile(ctx, s))
case ReturnStatement(Some(e)) => case ReturnStatement(Some(e)) =>
m.returnType match { m.returnType match {
case _: BooleanType => case _: BooleanType =>

View File

@ -0,0 +1,176 @@
package millfork.compiler
import millfork.CompilationFlag
import millfork.assembly.{AssemblyLine, OpcodeClasses}
import millfork.env._
import millfork.error.ErrorReporting
import millfork.node._
import scala.collection.mutable
/**
* @author Karol Stasiak
*/
object ReturnDispatch {
def compile(ctx: CompilationContext, stmt: ReturnDispatchStatement): List[AssemblyLine] = {
if (stmt.branches.isEmpty) {
ErrorReporting.error("At least one branch is required", stmt.position)
return Nil
}
def toConstant(e: Expression) = {
ctx.env.eval(e).getOrElse {
ErrorReporting.error("Non-constant parameter for dispatch branch", e.position)
Constant.Zero
}
}
def toInt(e: Expression): Int = {
ctx.env.eval(e) match {
case Some(NumericConstant(i, _)) =>
if (i < 0 || i > 255) ErrorReporting.error("Branch labels have to be in the 0-255 range", e.position)
i.toInt & 0xff
case _ =>
ErrorReporting.error("Branch labels have to early resolvable constants", e.position)
0
}
}
val indexerType = MfCompiler.getExpressionType(ctx, stmt.indexer)
if (indexerType.size != 1) {
ErrorReporting.error("Return dispatch index expression type has to be a byte", stmt.indexer.position)
}
if (indexerType.isSigned) {
ErrorReporting.warn("Return dispatch index expression type will be automatically casted to unsigned", ctx.options, stmt.indexer.position)
}
stmt.params.foreach{
case e@VariableExpression(name) =>
if (ctx.env.get[Variable](name).typ.size != 1) {
ErrorReporting.error("Dispatch parameters should be bytes", e.position)
}
case _ => ()
}
val returnType = ctx.function.returnType
val map = mutable.Map[Int, (Constant, List[Constant])]()
var min = Option.empty[Int]
var max = Option.empty[Int]
var default = Option.empty[(Constant, List[Constant])]
stmt.branches.foreach { branch =>
val function = ctx.env.evalForAsm(branch.function).getOrElse {
ErrorReporting.error("Non-constant function address for dispatch branch", branch.function.position)
Constant.Zero
}
if (returnType.name != "void") {
function match {
case MemoryAddressConstant(f: FunctionInMemory) =>
if (f.returnType.name != returnType.name) {
ErrorReporting.warn(s"Dispatching to a function of different return type: dispatcher return type: ${returnType.name}, dispatchee return type: ${f.returnType.name}", ctx.options, branch.function.position)
}
case _ => ()
}
}
val params = branch.params.map(toConstant)
if (params.length > stmt.params.length) {
ErrorReporting.error("Too many parameters for dispatch branch", branch.params.head.position)
}
branch.label match {
case DefaultReturnDispatchLabel(start, end) =>
if (default.isDefined) {
ErrorReporting.error(s"Duplicate default dispatch label", branch.position)
}
min = start.map(toInt)
max = end.map(toInt)
default = Some(function -> params)
case StandardReturnDispatchLabel(labels) =>
labels.foreach { label =>
val i = toInt(label)
if (map.contains(i)) {
ErrorReporting.error(s"Duplicate dispatch label: $label = $i", label.position)
}
map(i) = function -> params
}
}
}
val nonDefaultMin = map.keys.reduceOption(_ min _)
val nonDefaultMax = map.keys.reduceOption(_ max _)
val defaultMin = min.orElse(nonDefaultMin).getOrElse(0)
val defaultMax = max.orElse(nonDefaultMax).getOrElse {
ErrorReporting.error("Undefined maximum label for dispatch", stmt.position)
defaultMin
}
val actualMin = defaultMin min nonDefaultMin.getOrElse(defaultMin)
val actualMax = defaultMax max nonDefaultMax.getOrElse(defaultMax)
val zeroes = Constant.Zero -> List[Constant]()
for (i <- actualMin to actualMax) {
if (!map.contains(i)) map(i) = default.getOrElse {
// TODO: warning?
zeroes
}
}
val compactParams = ctx.options.flag(CompilationFlag.CompactReturnDispatchParams)
val paramMins = stmt.params.indices.map { paramIndex =>
if (compactParams) map.filter(_._2._2.length > paramIndex).keys.reduceOption(_ min _).getOrElse(0)
else actualMin
}
val paramMaxes = stmt.params.indices.map { paramIndex =>
if (compactParams) map.filter(_._2._2.length > paramIndex).keys.reduceOption(_ max _).getOrElse(0)
else actualMax
}
var env = ctx.env
while (env.parent.isDefined) env = env.parent.get
val label = MfCompiler.nextLabel("di")
val paramArrays = stmt.params.indices.map { ix =>
val a = InitializedArray(label + "$" + ix + ".array", None, (paramMins(ix) to paramMaxes(ix)).map { key =>
map(key)._2.lift(ix).getOrElse(Constant.Zero)
}.toList)
env.registerUnnamedArray(a)
a
}
val useJmpaix = ctx.options.flag(CompilationFlag.EmitCmosOpcodes) && (actualMax - actualMin) <= 127
val b = ctx.env.get[Type]("byte")
import millfork.assembly.AddrMode._
import millfork.assembly.Opcode._
val ctxForStoringParams = ctx.neverCheckArrayBounds
val copyParams = stmt.params.zipWithIndex.flatMap { case (paramVar, paramIndex) =>
val storeParam = MfCompiler.compileByteStorage(ctxForStoringParams, Register.A, paramVar)
if (storeParam.exists(l => OpcodeClasses.ChangesX(l.opcode)))
ErrorReporting.error("Invalid/too complex target parameter variable", paramVar.position)
AssemblyLine.absoluteX(LDA, paramArrays(paramIndex), -paramMins(paramIndex)) :: storeParam
}
if (useJmpaix) {
val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(map(i)._1.loByte, map(i)._1.hiByte)).toList)
env.registerUnnamedArray(jumpTable)
if (copyParams.isEmpty) {
val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.A, b)), BranchSpec.None)
loadIndex ++ List(AssemblyLine.implied(ASL), AssemblyLine.implied(TAX)) ++ copyParams :+ AssemblyLine(JMP, AbsoluteIndexedX, jumpTable.toAddress - actualMin * 2)
} else {
val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.X, b)), BranchSpec.None)
loadIndex ++ copyParams ++ List(
AssemblyLine.implied(TXA),
AssemblyLine.implied(ASL),
AssemblyLine.implied(TAX),
AssemblyLine(JMP, AbsoluteIndexedX, jumpTable.toAddress - actualMin * 2))
}
} else {
val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.X, b)), BranchSpec.None)
val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).loByte).toList)
val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).hiByte).toList)
env.registerUnnamedArray(jumpTableLo)
env.registerUnnamedArray(jumpTableHi)
loadIndex ++ copyParams ++ List(
AssemblyLine.absoluteX(LDA, jumpTableHi.toAddress - actualMin),
AssemblyLine.implied(PHA),
AssemblyLine.absoluteX(LDA, jumpTableLo.toAddress - actualMin),
AssemblyLine.implied(PHA),
AssemblyLine.implied(RTS))
}
}
}

View File

@ -36,7 +36,7 @@ sealed trait Constant {
def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that)) def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that))
def -(that: Long): Constant = this + (-that) def -(that: Long): Constant = if (that == 0) this else this - NumericConstant(that, minimumSize(that))
def loByte: Constant = { def loByte: Constant = {
if (requiredSize == 1) return this if (requiredSize == 1) return this
@ -87,6 +87,8 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant {
override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that)) override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that))
override def -(that: Long) = NumericConstant(value - that, minimumSize(value - that))
override def toString: String = if (value > 9) value.formatted("$%X") else value.toString override def toString: String = if (value > 9) value.formatted("$%X") else value.toString
override def isRelatedTo(v: Variable): Boolean = false override def isRelatedTo(v: Variable): Boolean = false
@ -115,7 +117,7 @@ case class HalfWordConstant(base: Constant, hi: Boolean) extends Constant {
override def requiredSize = 1 override def requiredSize = 1
override def toString: String = base + (if (hi) ".hi" else ".lo") override def toString: String = (if (base.isInstanceOf[CompoundConstant]) s"($base)" else base) + (if (hi) ".hi" else ".lo")
override def isRelatedTo(v: Variable): Boolean = base.isRelatedTo(v) override def isRelatedTo(v: Variable): Boolean = base.isRelatedTo(v)
} }
@ -192,6 +194,8 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co
} }
} }
override def -(that: Long): Constant = this + (-that)
override def +(that: Long): Constant = { override def +(that: Long): Constant = {
if (that == 0) { if (that == 0) {
return this return this

View File

@ -434,7 +434,7 @@ class Environment(val parent: Option[Environment], val prefix: String) {
} }
val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement]) val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement])
if (stmt.inlined) { if (stmt.inlined) {
val mangled = new InlinedFunction( val mangled = InlinedFunction(
name, name,
resultType, resultType,
params, params,
@ -500,6 +500,16 @@ class Environment(val parent: Option[Environment], val prefix: String) {
} }
} }
def registerUnnamedArray(array: InitializedArray): Unit = {
val b = get[Type]("byte")
val p = get[Type]("pointer")
if (!array.name.endsWith(".array")) ???
val pointerName = array.name.stripSuffix(".array")
addThing(ConstantThing(pointerName, array.toAddress, p), None)
addThing(ConstantThing(pointerName + ".addr", array.toAddress, p), None)
addThing(array, None)
}
def registerArray(stmt: ArrayDeclarationStatement): Unit = { def registerArray(stmt: ArrayDeclarationStatement): Unit = {
val b = get[Type]("byte") val b = get[Type]("byte")
val p = get[Type]("pointer") val p = get[Type]("pointer")

View File

@ -223,7 +223,9 @@ case class NormalFunction(name: String,
override def shouldGenerate = true override def shouldGenerate = true
} }
case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing {
def map(f: Constant => Constant) = ConstantThing("", f(value), typ)
}
trait ParamSignature { trait ParamSignature {
def types: List[Type] def types: List[Type]

View File

@ -138,6 +138,26 @@ case class ReturnStatement(value: Option[Expression]) extends ExecutableStatemen
override def getAllExpressions: List[Expression] = value.toList override def getAllExpressions: List[Expression] = value.toList
} }
trait ReturnDispatchLabel extends Node {
def getAllExpressions: List[Expression]
}
case class DefaultReturnDispatchLabel(start: Option[Expression], end: Option[Expression]) extends ReturnDispatchLabel {
def getAllExpressions: List[Expression] = List(start, end).flatten
}
case class StandardReturnDispatchLabel(labels:List[Expression]) extends ReturnDispatchLabel {
def getAllExpressions: List[Expression] = labels
}
case class ReturnDispatchBranch(label: ReturnDispatchLabel, function: Expression, params: List[Expression]) extends Node {
def getAllExpressions: List[Expression] = label.getAllExpressions ++ params
}
case class ReturnDispatchStatement(indexer: Expression, params: List[LhsExpression], branches: List[ReturnDispatchBranch]) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = indexer :: params ++ branches.flatMap(_.getAllExpressions)
}
case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement { case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement {
override def getAllExpressions: List[Expression] = List(destination, source) override def getAllExpressions: List[Expression] = List(destination, source)
} }

View File

@ -56,6 +56,8 @@ object UnusedFunctions extends NodeOptimization {
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil))
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)
case s: ReturnDispatchStatement =>
getAllCalledFunctions(s.getAllExpressions) ++ getAllCalledFunctions(s.branches.map(_.function))
case s: Statement => getAllCalledFunctions(s.getAllExpressions) case s: Statement => getAllCalledFunctions(s.getAllExpressions)
case s: VariableExpression => List( case s: VariableExpression => List(
s.name, s.name,

View File

@ -171,6 +171,12 @@ class Assembler(private val program: Program, private val rootEnv: Environment)
} }
} }
rootEnv.things.foreach{case (name, thing) =>
if (!env.things.contains(name)) {
env.things(name) = thing
}
}
val bank0 = mem.banks(0) val bank0 = mem.banks(0)
env.allPreallocatables.foreach { env.allPreallocatables.foreach {

View File

@ -25,11 +25,12 @@ object InliningCalculator {
program.declarations.foreach{ program.declarations.foreach{
case f:FunctionDeclarationStatement => case f:FunctionDeclarationStatement =>
allFunctions += f.name allFunctions += f.name
if (f.inlined) badFunctions += f.name if (f.inlined
if (f.address.isDefined) badFunctions += f.name || f.address.isDefined
if (f.interrupt) badFunctions += f.name || f.interrupt
if (f.reentrant) badFunctions += f.name || f.reentrant
if (f.name == "main") badFunctions += f.name || f.name == "main"
|| f.statements.exists(_.lastOption.exists(_.isInstanceOf[ReturnDispatchStatement]))) badFunctions += f.name
case _ => case _ =>
} }
allFunctions --= badFunctions allFunctions --= badFunctions
@ -38,6 +39,8 @@ object InliningCalculator {
private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap {
case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList)
case ReturnDispatchStatement(index, params, branches) =>
getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function))
case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil))
case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil))
case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil)

View File

@ -191,7 +191,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
} yield { } yield {
val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString)) val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString))
val slice = optSlice.fold(data) { val slice = optSlice.fold(data) {
case (start, length) => data.drop(start.value.toInt).take(length.value.toInt) case (start, length) => data.slice(start.value.toInt, start.value.toInt + length.value.toInt)
} }
slice.map(c => LiteralExpression(c & 0xff, 1)).toList slice.map(c => LiteralExpression(c & 0xff, 1)).toList
} }
@ -212,6 +212,8 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO
def tightMlExpressionButNotCall: P[Expression] = P(mlParenExpr | mlIndexedExpression | atom) // TODO
def mlExpression(level: Int): P[Expression] = { def mlExpression(level: Int): P[Expression] = {
val allowedOperators = mlOperators.drop(level).flatten val allowedOperators = mlOperators.drop(level).flatten
@ -285,7 +287,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
case (p, l, r) => Assignment(l, r).pos(p) case (p, l, r) => Assignment(l, r).pos(p)
} }
def keywordStatement: P[ExecutableStatement] = P(returnStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement) def keywordStatement: P[ExecutableStatement] = P(returnOrDispatchStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement)
def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) } def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) }
@ -336,7 +338,34 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o
def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}" def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}"
def returnStatement: P[ExecutableStatement] = ("return" ~ !letterOrDigit ~/ HWS ~ mlExpression(nonStatementLevel).?).map(ReturnStatement) def dispatchLabel: P[ReturnDispatchLabel] =
("default" ~ !letterOrDigit ~/ AWS ~/ ("(" ~/ position("default branch range") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?).map{
case None => DefaultReturnDispatchLabel(None, None)
case Some((_, Seq())) => DefaultReturnDispatchLabel(None, None)
case Some((_, Seq(e))) => DefaultReturnDispatchLabel(None, Some(e))
case Some((_, Seq(s, e))) => DefaultReturnDispatchLabel(Some(s), Some(e))
case Some((pos, _)) =>
ErrorReporting.error("Invalid default branch declaration", Some(pos))
DefaultReturnDispatchLabel(None, None)
} | mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS).map(exprs => StandardReturnDispatchLabel(exprs.toList))
def dispatchBranch: P[ReturnDispatchBranch] = for {
pos <- position()
l <- dispatchLabel ~/ HWS ~/ "@" ~/ HWS
f <- tightMlExpressionButNotCall ~/ HWS
parameters <- ("(" ~/ position("dispatch actual parameters") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?
} yield ReturnDispatchBranch(l, f, parameters.map(_._2.toList).getOrElse(Nil)).pos(pos)
def dispatchStatementBody: P[ExecutableStatement] = for {
indexer <- "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~/ AWS ~/ "]" ~/ AWS
_ <- position("dispatch statement body")
parameters <- ("(" ~/ position("dispatch parameters") ~ AWS ~/ mlLhsExpression.rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?
_ <- AWS ~/ position("dispatch statement body") ~/ "{" ~/ AWS
branches <- dispatchBranch.rep(sep = EOL ~ !"}" ~/ Pass)
_ <- AWS ~/ "}"
} yield ReturnDispatchStatement(indexer, parameters.map(_._2.toList).getOrElse(Nil), branches.toList)
def returnOrDispatchStatement: P[ExecutableStatement] = "return" ~ !letterOrDigit ~/ HWS ~ (dispatchStatementBody | mlExpression(nonStatementLevel).?.map(ReturnStatement))
def ifStatement: P[ExecutableStatement] = for { def ifStatement: P[ExecutableStatement] = for {
condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel)

View File

@ -132,4 +132,19 @@ class AssemblySuite extends FunSuite with Matchers {
| } | }
""".stripMargin)(_.readByte(0xc000) should equal(10)) """.stripMargin)(_.readByte(0xc000) should equal(10))
} }
test("JSR") {
EmuBenchmarkRun(
"""
| byte output @$c000
| asm void main () {
| JSR thing
| RTS
| }
|
| void thing() {
| output = 10
| }
""".stripMargin)(_.readByte(0xc000) should equal(10))
}
} }

View File

@ -0,0 +1,74 @@
package millfork.test
import millfork.test.emu.{EmuBenchmarkRun, EmuCmosBenchmarkRun, EmuUnoptimizedRun}
import org.scalatest.{FunSuite, Matchers}
/**
* @author Karol Stasiak
*/
class ReturnDispatchSuite extends FunSuite with Matchers {
test("Trivial test") {
EmuCmosBenchmarkRun(
"""
| byte output @$c000
| void main () {
| byte i
| i = 1
| return [i] {
| 1 @ success
| }
| }
| void success() {
| output = 42
| }
""".stripMargin) { m =>
m.readByte(0xc000) should equal(42)
}
}
test("Parameter test") {
EmuCmosBenchmarkRun(
"""
| array output [200] @$c000
| sbyte param
| byte ptr
| const byte L = 4
| const byte R = 5
| const byte W1 = 6
| const byte W2 = 7
| void main () {
| ptr = 0
| handler(W1)
| handler(R)
| handler(W2)
| handler(R)
| handler(W1)
| handler(L)
| handler(L)
| handler(10)
| }
| void handler(byte i) {
| return [i](param) {
| L @ move($ff)
| R @ move(1)
| W1 @ write(1)
| W2 @ write(2)
| default(0,10) @ zero
| }
| }
| void move() {
| ptr += param
| }
| void write() {
| output[ptr] = param
| }
| void zero() {
| output[ptr] = 42
| }
""".stripMargin) { m =>
m.readByte(0xc000) should equal(42)
m.readByte(0xc001) should equal(2)
m.readByte(0xc002) should equal(1)
}
}
}

View File

@ -90,10 +90,12 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
Console.err.flush() Console.err.flush()
println(source) println(source)
val platform = EmuPlatform.get(cpu) val platform = EmuPlatform.get(cpu)
val options = new CompilationOptions(platform, Map( val options = CompilationOptions(platform, Map(
CompilationFlag.EmitIllegals -> this.emitIllegals, CompilationFlag.EmitIllegals -> this.emitIllegals,
CompilationFlag.DetailedFlowAnalysis -> quantum, CompilationFlag.DetailedFlowAnalysis -> quantum,
CompilationFlag.InlineFunctions -> this.inline, CompilationFlag.InlineFunctions -> this.inline,
CompilationFlag.CompactReturnDispatchParams -> true,
CompilationFlag.EmitCmosOpcodes -> (platform.cpu == millfork.Cpu.Cmos),
// CompilationFlag.CheckIndexOutOfBounds -> true, // CompilationFlag.CheckIndexOutOfBounds -> true,
)) ))
ErrorReporting.hasErrors = false ErrorReporting.hasErrors = false
@ -113,7 +115,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
val hasOptimizations = assemblyOptimizations.nonEmpty val hasOptimizations = assemblyOptimizations.nonEmpty
var unoptimizedSize = 0L var unoptimizedSize = 0L
// print asm // print unoptimized asm
env.allPreallocatables.foreach { env.allPreallocatables.foreach {
case f: NormalFunction => case f: NormalFunction =>
val result = MfCompiler.compile(CompilationContext(f.environment, f, 0, options)) val result = MfCompiler.compile(CompilationContext(f.environment, f, 0, options))
@ -129,7 +131,9 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
// compile // compile
val assembler = new Assembler(program, env) val env2 = new Environment(None, "")
env2.collectDeclarations(program, options)
val assembler = new Assembler(program, env2)
val output = assembler.assemble(callGraph, assemblyOptimizations, options) val output = assembler.assemble(callGraph, assemblyOptimizations, options)
println(";;; compiled: -----------------") println(";;; compiled: -----------------")
output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).foreach(println) output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).foreach(println)
@ -148,6 +152,11 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization],
ErrorReporting.assertNoErrors("Code generation failed") ErrorReporting.assertNoErrors("Code generation failed")
val memoryBank = assembler.mem.banks(0) val memoryBank = assembler.mem.banks(0)
if (source.contains("return [")) {
for (_ <- 0 until 10; i <- 0xfffe.to(0, -1)) {
if (memoryBank.readable(i)) memoryBank.readable(i + 1) = true
}
}
platform.cpu match { platform.cpu match {
case millfork.Cpu.Cmos => case millfork.Cpu.Cmos =>
runViaSymon(memoryBank, platform.org, CpuBehavior.CMOS_6502) runViaSymon(memoryBank, platform.org, CpuBehavior.CMOS_6502)