diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..7fb63497 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,13 @@ +#Change log + +##Current version + +* Added return dispatch statements. + +* Fixed several optimization bugs. + +* Other minor improvements. + +## 0.1 + +* Initial numbered version. \ No newline at end of file diff --git a/doc/abi/undefined-behaviour.md b/doc/abi/undefined-behaviour.md index 431569c7..8691dc60 100644 --- a/doc/abi/undefined-behaviour.md +++ b/doc/abi/undefined-behaviour.md @@ -12,6 +12,12 @@ even up to hardware damage. * reading uninitialized variables: will return undefined values +* reading variables used by return dispatch statements but not assigned a value: will return undefined values + +* returning a value from a function by return dispatch to a function of different return type: will return undefined values + +* passing an index out of range for a return dispatch statement + * stack overflow: exhausting the hardware stack due to excess recursion, excess function calls or excess stack-allocated variables * on ROM-based platforms: writing to arrays diff --git a/doc/api/command-line.md b/doc/api/command-line.md index b3cc28a1..7cf61c84 100644 --- a/doc/api/command-line.md +++ b/doc/api/command-line.md @@ -31,17 +31,28 @@ ## Code generation options -* `-fcmos-ops`, `-fno-cmos-ops` – Whether should emit CMOS opcodes. `.ini` equivalent: `emit_cmos`. +* `-fcmos-ops`, `-fno-cmos-ops` – Whether should emit CMOS opcodes. +`.ini` equivalent: `emit_cmos`. Default: yes if targeting 65C02, no otherwise. -* `-fillegals`, `-fno-illegals` – Whether should emit illegal (undocumented) NMOS opcodes. `.ini` equivalent: `emit_illegals`. +* `-fillegals`, `-fno-illegals` – Whether should emit illegal (undocumented) NMOS opcodes. +`.ini` equivalent: `emit_illegals`. Default: no. -* `-fjmp-fix`, `-fno-jmp-fix` – Whether should prevent indirect JMP bug on page boundary. `.ini` equivalent: `prevent_jmp_indirect_bug`. +* `-fjmp-fix`, `-fno-jmp-fix` – Whether should prevent indirect JMP bug on page boundary. +`.ini` equivalent: `prevent_jmp_indirect_bug`. Default: no if targeting 65C02, yes otherwise. -* `-fdecimal-mode`, `-fno-decimal-mode` – Whether decimal mode should be available. `.ini` equivalent: `decimal_mode`. +* `-fdecimal-mode`, `-fno-decimal-mode` – Whether decimal mode should be available. +`.ini` equivalent: `decimal_mode`. Default: no if targeting Ricoh, yes otherwise. -* `-fvariable-overlap`, `-fno-variable-overlap` – Whether variables should overlap if their scopes do not intersect. Default: yes. +* `-fvariable-overlap`, `-fno-variable-overlap` – Whether variables should overlap if their scopes do not intersect. +Default: yes. -* `-fbounds-checking`, `-fnobounds-checking` – Whether should insert bounds checking on array access. Default: no. +* `-fbounds-checking`, `-fnobounds-checking` – Whether should insert bounds checking on array access. +Default: no. + +* `-fcompact-dispatch-params`, `-fnocompact-dispatch-params` – +Whether parameter values in return dispatch statements may overlap other objects. +This may cause problems if the parameter table is stored next to a hardware register that has side effects when reading. +`.ini` equivalent: `compact_dispatch_params`. Default: yes. ## Optimization options diff --git a/doc/lang/syntax.md b/doc/lang/syntax.md index ec07e046..70cfa8ee 100644 --- a/doc/lang/syntax.md +++ b/doc/lang/syntax.md @@ -79,6 +79,72 @@ if { } ``` +### `return` statement + +Syntax: + +``` +return +``` +``` +return +``` + +### `return[]` statement (return dispatch) + +Syntax examples: + +``` +return [a + b] { + 0 @ underflow + 255 @ overflow + default @ nothing +} +``` +``` +return [getF()] { + 1 @ function1 + 2 @ function2 + default(5) @ functionDefault +} +``` +``` +return [i] (param1, param2) { + 1,5,8 @ function1(4, 6) + 2 @ function2(9) + default(0,20) @ functionDefault +} +``` + +Return dispatch calculates the value of an index, picks the correct branch, +assigns some global variables and jumps to another function. + +The index has to evaluate to a byte. The functions cannot be `inline` and shouldn't have parameters. +Jumping to a function with parameters gives those parameters undefined values. + +The functions are not called, so they don't return to the function the return dispatch statement is in, but to its caller. +The return values are passed along. If the dispatching function has a non-`void` return type different that the type +of the function dispatched to, the return value is undefined. + +If the `default` branch exists, then it is used for every missing index value between other supported values. +Optional parameters to `default` specify the maximum, or both the minimum and maximum supported index value. +In the above examples: the first example supports values 0–255, second 1–5, and third 0–20. + +If the index has an unsupported value, the behaviour is formally undefined, but in practice the program will simply crash. + +Before jumping to the function, the chosen global variables will be assigned parameter values. +Variables have to be global byte-sized. Some simple array indexing expressions are also allowed. +Parameter values have to be constants. +For example, in the third example one of the following will happen: + +* if `i` is 1, 5 or 8, then `param1` is assigned 4, `param2` is assigned 6 and then `function1` is called; + +* if `i` is 2, then `param1` is assigned 9, `param2` is assigned an undefined value and then `function2` is called; + +* if `i` is any other value from 0 to 20, then `param1` and `param2` are assigned undefined values and then `functionDefault` is called; + +* if `i` has any other value, then undefined behaviour. + ### `while` and `do-while` statements Syntax: diff --git a/src/main/scala/millfork/CompilationOptions.scala b/src/main/scala/millfork/CompilationOptions.scala index 0d40e8c1..dce46541 100644 --- a/src/main/scala/millfork/CompilationOptions.scala +++ b/src/main/scala/millfork/CompilationOptions.scala @@ -5,7 +5,7 @@ import millfork.error.ErrorReporting /** * @author Karol Stasiak */ -class CompilationOptions(val platform: Platform, val commandLineFlags: Map[CompilationFlag.Value, Boolean]) { +case class CompilationOptions(platform: Platform, commandLineFlags: Map[CompilationFlag.Value, Boolean]) { import CompilationFlag._ import Cpu._ @@ -46,11 +46,11 @@ object Cpu extends Enumeration { import CompilationFlag._ def defaultFlags(x: Cpu.Value): Set[CompilationFlag.Value] = x match { - case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) - case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap) - case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap) - case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap) - case Cmos => Set(EmitCmosOpcodes, VariableOverlap) + case StrictMos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams) + case Mos => Set(DecimalMode, PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams) + case Ricoh => Set(PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams) + case StrictRicoh => Set(PreventJmpIndirectBug, VariableOverlap, CompactReturnDispatchParams) + case Cmos => Set(EmitCmosOpcodes, VariableOverlap, CompactReturnDispatchParams) } def fromString(name: String): Cpu.Value = name match { @@ -77,7 +77,7 @@ object CompilationFlag extends Enumeration { // optimization options: DetailedFlowAnalysis, DangerousOptimizations, InlineFunctions, // memory allocation options - VariableOverlap, + VariableOverlap, CompactReturnDispatchParams, // runtime check options CheckIndexOutOfBounds, // warning options @@ -94,6 +94,7 @@ object CompilationFlag extends Enumeration { "ro_arrays" -> ReadOnlyArrays, "ror_warn" -> RorWarning, "prevent_jmp_indirect_bug" -> PreventJmpIndirectBug, + "compact_dispatch_params" -> CompactReturnDispatchParams, ) } \ No newline at end of file diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala index 1066179b..5f80a397 100644 --- a/src/main/scala/millfork/Main.scala +++ b/src/main/scala/millfork/Main.scala @@ -63,7 +63,7 @@ object Main { ErrorReporting.info("No platform selected, defaulting to `c64`") "c64" }) - val options = new CompilationOptions(platform, c.flags) + val options = CompilationOptions(platform, c.flags) ErrorReporting.debug("Effective flags: " + options.flags) val output = c.outputFileName.getOrElse("a") @@ -195,6 +195,9 @@ object Main { boolean("-fvariable-overlap", "-fno-variable-overlap").action { (c, v) => c.changeFlag(CompilationFlag.VariableOverlap, v) }.description("Whether variables should overlap if their scopes do not intersect.") + boolean("-fcompact-dispatch-params", "-fno-compact-dispatch-params").action { (c, v) => + c.changeFlag(CompilationFlag.CompactReturnDispatchParams, v) + }.description("Whether parameter values in return dispatch statements may overlap other objects.") boolean("-fbounds-checking", "-fno-bounds-checking").action { (c, v) => c.changeFlag(CompilationFlag.VariableOverlap, v) }.description("Whether should insert bounds checking on array access.") diff --git a/src/main/scala/millfork/assembly/AssemblyLine.scala b/src/main/scala/millfork/assembly/AssemblyLine.scala index 72b14d56..5c53e08e 100644 --- a/src/main/scala/millfork/assembly/AssemblyLine.scala +++ b/src/main/scala/millfork/assembly/AssemblyLine.scala @@ -306,7 +306,7 @@ case class AssemblyLine(opcode: Opcode.Value, addrMode: AddrMode.Value, var para def sizeInBytes: Int = addrMode match { case Implied => 1 case Relative | ZeroPageX | ZeroPage | ZeroPageY | IndexedX | IndexedY | Immediate => 2 - case AbsoluteX | Absolute | AbsoluteY | Indirect => 3 + case AbsoluteIndexedX | AbsoluteX | Absolute | AbsoluteY | Indirect => 3 case DoesNotExist => 0 } diff --git a/src/main/scala/millfork/compiler/CompilationContext.scala b/src/main/scala/millfork/compiler/CompilationContext.scala index a321eb54..f5d3def2 100644 --- a/src/main/scala/millfork/compiler/CompilationContext.scala +++ b/src/main/scala/millfork/compiler/CompilationContext.scala @@ -9,4 +9,7 @@ import millfork.env.{Environment, MangledFunction, NormalFunction} case class CompilationContext(env: Environment, function: NormalFunction, extraStackOffset: Int, options: CompilationOptions){ def addStack(i: Int): CompilationContext = this.copy(extraStackOffset = extraStackOffset + i) + + def neverCheckArrayBounds: CompilationContext = + this.copy(options = options.copy(commandLineFlags = options.commandLineFlags + (CompilationFlag.CheckIndexOutOfBounds -> false))) } diff --git a/src/main/scala/millfork/compiler/MfCompiler.scala b/src/main/scala/millfork/compiler/MfCompiler.scala index 61183cd8..f668d72a 100644 --- a/src/main/scala/millfork/compiler/MfCompiler.scala +++ b/src/main/scala/millfork/compiler/MfCompiler.scala @@ -1509,6 +1509,8 @@ object MfCompiler { List(AssemblyLine.discardYF()) ++ returnInstructions) } } + case s : ReturnDispatchStatement => + LinearChunk(ReturnDispatch.compile(ctx, s)) case ReturnStatement(Some(e)) => m.returnType match { case _: BooleanType => diff --git a/src/main/scala/millfork/compiler/ReturnDispatch.scala b/src/main/scala/millfork/compiler/ReturnDispatch.scala new file mode 100644 index 00000000..4e17164c --- /dev/null +++ b/src/main/scala/millfork/compiler/ReturnDispatch.scala @@ -0,0 +1,176 @@ +package millfork.compiler + +import millfork.CompilationFlag +import millfork.assembly.{AssemblyLine, OpcodeClasses} +import millfork.env._ +import millfork.error.ErrorReporting +import millfork.node._ + +import scala.collection.mutable + +/** + * @author Karol Stasiak + */ +object ReturnDispatch { + + def compile(ctx: CompilationContext, stmt: ReturnDispatchStatement): List[AssemblyLine] = { + if (stmt.branches.isEmpty) { + ErrorReporting.error("At least one branch is required", stmt.position) + return Nil + } + + def toConstant(e: Expression) = { + ctx.env.eval(e).getOrElse { + ErrorReporting.error("Non-constant parameter for dispatch branch", e.position) + Constant.Zero + } + } + + def toInt(e: Expression): Int = { + ctx.env.eval(e) match { + case Some(NumericConstant(i, _)) => + if (i < 0 || i > 255) ErrorReporting.error("Branch labels have to be in the 0-255 range", e.position) + i.toInt & 0xff + case _ => + ErrorReporting.error("Branch labels have to early resolvable constants", e.position) + 0 + } + } + + val indexerType = MfCompiler.getExpressionType(ctx, stmt.indexer) + if (indexerType.size != 1) { + ErrorReporting.error("Return dispatch index expression type has to be a byte", stmt.indexer.position) + } + if (indexerType.isSigned) { + ErrorReporting.warn("Return dispatch index expression type will be automatically casted to unsigned", ctx.options, stmt.indexer.position) + } + stmt.params.foreach{ + case e@VariableExpression(name) => + if (ctx.env.get[Variable](name).typ.size != 1) { + ErrorReporting.error("Dispatch parameters should be bytes", e.position) + } + case _ => () + } + + val returnType = ctx.function.returnType + val map = mutable.Map[Int, (Constant, List[Constant])]() + var min = Option.empty[Int] + var max = Option.empty[Int] + var default = Option.empty[(Constant, List[Constant])] + stmt.branches.foreach { branch => + val function = ctx.env.evalForAsm(branch.function).getOrElse { + ErrorReporting.error("Non-constant function address for dispatch branch", branch.function.position) + Constant.Zero + } + if (returnType.name != "void") { + function match { + case MemoryAddressConstant(f: FunctionInMemory) => + if (f.returnType.name != returnType.name) { + ErrorReporting.warn(s"Dispatching to a function of different return type: dispatcher return type: ${returnType.name}, dispatchee return type: ${f.returnType.name}", ctx.options, branch.function.position) + } + case _ => () + } + } + val params = branch.params.map(toConstant) + if (params.length > stmt.params.length) { + ErrorReporting.error("Too many parameters for dispatch branch", branch.params.head.position) + } + branch.label match { + case DefaultReturnDispatchLabel(start, end) => + if (default.isDefined) { + ErrorReporting.error(s"Duplicate default dispatch label", branch.position) + } + min = start.map(toInt) + max = end.map(toInt) + default = Some(function -> params) + case StandardReturnDispatchLabel(labels) => + labels.foreach { label => + val i = toInt(label) + if (map.contains(i)) { + ErrorReporting.error(s"Duplicate dispatch label: $label = $i", label.position) + } + map(i) = function -> params + } + } + } + val nonDefaultMin = map.keys.reduceOption(_ min _) + val nonDefaultMax = map.keys.reduceOption(_ max _) + val defaultMin = min.orElse(nonDefaultMin).getOrElse(0) + val defaultMax = max.orElse(nonDefaultMax).getOrElse { + ErrorReporting.error("Undefined maximum label for dispatch", stmt.position) + defaultMin + } + val actualMin = defaultMin min nonDefaultMin.getOrElse(defaultMin) + val actualMax = defaultMax max nonDefaultMax.getOrElse(defaultMax) + val zeroes = Constant.Zero -> List[Constant]() + for (i <- actualMin to actualMax) { + if (!map.contains(i)) map(i) = default.getOrElse { + // TODO: warning? + zeroes + } + } + + val compactParams = ctx.options.flag(CompilationFlag.CompactReturnDispatchParams) + val paramMins = stmt.params.indices.map { paramIndex => + if (compactParams) map.filter(_._2._2.length > paramIndex).keys.reduceOption(_ min _).getOrElse(0) + else actualMin + } + val paramMaxes = stmt.params.indices.map { paramIndex => + if (compactParams) map.filter(_._2._2.length > paramIndex).keys.reduceOption(_ max _).getOrElse(0) + else actualMax + } + + var env = ctx.env + while (env.parent.isDefined) env = env.parent.get + val label = MfCompiler.nextLabel("di") + val paramArrays = stmt.params.indices.map { ix => + val a = InitializedArray(label + "$" + ix + ".array", None, (paramMins(ix) to paramMaxes(ix)).map { key => + map(key)._2.lift(ix).getOrElse(Constant.Zero) + }.toList) + env.registerUnnamedArray(a) + a + } + + val useJmpaix = ctx.options.flag(CompilationFlag.EmitCmosOpcodes) && (actualMax - actualMin) <= 127 + val b = ctx.env.get[Type]("byte") + + import millfork.assembly.AddrMode._ + import millfork.assembly.Opcode._ + + val ctxForStoringParams = ctx.neverCheckArrayBounds + val copyParams = stmt.params.zipWithIndex.flatMap { case (paramVar, paramIndex) => + val storeParam = MfCompiler.compileByteStorage(ctxForStoringParams, Register.A, paramVar) + if (storeParam.exists(l => OpcodeClasses.ChangesX(l.opcode))) + ErrorReporting.error("Invalid/too complex target parameter variable", paramVar.position) + AssemblyLine.absoluteX(LDA, paramArrays(paramIndex), -paramMins(paramIndex)) :: storeParam + } + + if (useJmpaix) { + val jumpTable = InitializedArray(label + "$jt.array", None, (actualMin to actualMax).flatMap(i => List(map(i)._1.loByte, map(i)._1.hiByte)).toList) + env.registerUnnamedArray(jumpTable) + if (copyParams.isEmpty) { + val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.A, b)), BranchSpec.None) + loadIndex ++ List(AssemblyLine.implied(ASL), AssemblyLine.implied(TAX)) ++ copyParams :+ AssemblyLine(JMP, AbsoluteIndexedX, jumpTable.toAddress - actualMin * 2) + } else { + val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.X, b)), BranchSpec.None) + loadIndex ++ copyParams ++ List( + AssemblyLine.implied(TXA), + AssemblyLine.implied(ASL), + AssemblyLine.implied(TAX), + AssemblyLine(JMP, AbsoluteIndexedX, jumpTable.toAddress - actualMin * 2)) + } + } else { + val loadIndex = MfCompiler.compile(ctx, stmt.indexer, Some(b -> RegisterVariable(Register.X, b)), BranchSpec.None) + val jumpTableLo = InitializedArray(label + "$jl.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).loByte).toList) + val jumpTableHi = InitializedArray(label + "$jh.array", None, (actualMin to actualMax).map(i => (map(i)._1 - 1).hiByte).toList) + env.registerUnnamedArray(jumpTableLo) + env.registerUnnamedArray(jumpTableHi) + loadIndex ++ copyParams ++ List( + AssemblyLine.absoluteX(LDA, jumpTableHi.toAddress - actualMin), + AssemblyLine.implied(PHA), + AssemblyLine.absoluteX(LDA, jumpTableLo.toAddress - actualMin), + AssemblyLine.implied(PHA), + AssemblyLine.implied(RTS)) + } + } +} diff --git a/src/main/scala/millfork/env/Constant.scala b/src/main/scala/millfork/env/Constant.scala index 1faac691..ebcd4a1f 100644 --- a/src/main/scala/millfork/env/Constant.scala +++ b/src/main/scala/millfork/env/Constant.scala @@ -36,7 +36,7 @@ sealed trait Constant { def +(that: Long): Constant = if (that == 0) this else this + NumericConstant(that, minimumSize(that)) - def -(that: Long): Constant = this + (-that) + def -(that: Long): Constant = if (that == 0) this else this - NumericConstant(that, minimumSize(that)) def loByte: Constant = { if (requiredSize == 1) return this @@ -87,6 +87,8 @@ case class NumericConstant(value: Long, requiredSize: Int) extends Constant { override def +(that: Long) = NumericConstant(value + that, minimumSize(value + that)) + override def -(that: Long) = NumericConstant(value - that, minimumSize(value - that)) + override def toString: String = if (value > 9) value.formatted("$%X") else value.toString override def isRelatedTo(v: Variable): Boolean = false @@ -115,7 +117,7 @@ case class HalfWordConstant(base: Constant, hi: Boolean) extends Constant { override def requiredSize = 1 - override def toString: String = base + (if (hi) ".hi" else ".lo") + override def toString: String = (if (base.isInstanceOf[CompoundConstant]) s"($base)" else base) + (if (hi) ".hi" else ".lo") override def isRelatedTo(v: Variable): Boolean = base.isRelatedTo(v) } @@ -192,6 +194,8 @@ case class CompoundConstant(operator: MathOperator.Value, lhs: Constant, rhs: Co } } + override def -(that: Long): Constant = this + (-that) + override def +(that: Long): Constant = { if (that == 0) { return this diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 3b2c1138..6409f4c2 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -434,7 +434,7 @@ class Environment(val parent: Option[Environment], val prefix: String) { } val needsExtraRTS = !stmt.inlined && !stmt.assembly && (statements.isEmpty || !statements.last.isInstanceOf[ReturnStatement]) if (stmt.inlined) { - val mangled = new InlinedFunction( + val mangled = InlinedFunction( name, resultType, params, @@ -500,6 +500,16 @@ class Environment(val parent: Option[Environment], val prefix: String) { } } + def registerUnnamedArray(array: InitializedArray): Unit = { + val b = get[Type]("byte") + val p = get[Type]("pointer") + if (!array.name.endsWith(".array")) ??? + val pointerName = array.name.stripSuffix(".array") + addThing(ConstantThing(pointerName, array.toAddress, p), None) + addThing(ConstantThing(pointerName + ".addr", array.toAddress, p), None) + addThing(array, None) + } + def registerArray(stmt: ArrayDeclarationStatement): Unit = { val b = get[Type]("byte") val p = get[Type]("pointer") diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index bf1e3345..5b0fa244 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -223,7 +223,9 @@ case class NormalFunction(name: String, override def shouldGenerate = true } -case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing +case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing { + def map(f: Constant => Constant) = ConstantThing("", f(value), typ) +} trait ParamSignature { def types: List[Type] diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index 9ec3e996..412c806f 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -138,6 +138,26 @@ case class ReturnStatement(value: Option[Expression]) extends ExecutableStatemen override def getAllExpressions: List[Expression] = value.toList } +trait ReturnDispatchLabel extends Node { + def getAllExpressions: List[Expression] +} + +case class DefaultReturnDispatchLabel(start: Option[Expression], end: Option[Expression]) extends ReturnDispatchLabel { + def getAllExpressions: List[Expression] = List(start, end).flatten +} + +case class StandardReturnDispatchLabel(labels:List[Expression]) extends ReturnDispatchLabel { + def getAllExpressions: List[Expression] = labels +} + +case class ReturnDispatchBranch(label: ReturnDispatchLabel, function: Expression, params: List[Expression]) extends Node { + def getAllExpressions: List[Expression] = label.getAllExpressions ++ params +} + +case class ReturnDispatchStatement(indexer: Expression, params: List[LhsExpression], branches: List[ReturnDispatchBranch]) extends ExecutableStatement { + override def getAllExpressions: List[Expression] = indexer :: params ++ branches.flatMap(_.getAllExpressions) +} + case class Assignment(destination: LhsExpression, source: Expression) extends ExecutableStatement { override def getAllExpressions: List[Expression] = List(destination, source) } diff --git a/src/main/scala/millfork/node/opt/UnusedFunctions.scala b/src/main/scala/millfork/node/opt/UnusedFunctions.scala index 726245bc..f4f9d23d 100644 --- a/src/main/scala/millfork/node/opt/UnusedFunctions.scala +++ b/src/main/scala/millfork/node/opt/UnusedFunctions.scala @@ -56,6 +56,8 @@ object UnusedFunctions extends NodeOptimization { case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) + case s: ReturnDispatchStatement => + getAllCalledFunctions(s.getAllExpressions) ++ getAllCalledFunctions(s.branches.map(_.function)) case s: Statement => getAllCalledFunctions(s.getAllExpressions) case s: VariableExpression => List( s.name, diff --git a/src/main/scala/millfork/output/Assembler.scala b/src/main/scala/millfork/output/Assembler.scala index f94b1db4..c203966d 100644 --- a/src/main/scala/millfork/output/Assembler.scala +++ b/src/main/scala/millfork/output/Assembler.scala @@ -171,6 +171,12 @@ class Assembler(private val program: Program, private val rootEnv: Environment) } } + rootEnv.things.foreach{case (name, thing) => + if (!env.things.contains(name)) { + env.things(name) = thing + } + } + val bank0 = mem.banks(0) env.allPreallocatables.foreach { diff --git a/src/main/scala/millfork/output/InliningCalculator.scala b/src/main/scala/millfork/output/InliningCalculator.scala index 3b309a12..a84f5104 100644 --- a/src/main/scala/millfork/output/InliningCalculator.scala +++ b/src/main/scala/millfork/output/InliningCalculator.scala @@ -25,11 +25,12 @@ object InliningCalculator { program.declarations.foreach{ case f:FunctionDeclarationStatement => allFunctions += f.name - if (f.inlined) badFunctions += f.name - if (f.address.isDefined) badFunctions += f.name - if (f.interrupt) badFunctions += f.name - if (f.reentrant) badFunctions += f.name - if (f.name == "main") badFunctions += f.name + if (f.inlined + || f.address.isDefined + || f.interrupt + || f.reentrant + || f.name == "main" + || f.statements.exists(_.lastOption.exists(_.isInstanceOf[ReturnDispatchStatement]))) badFunctions += f.name case _ => } allFunctions --= badFunctions @@ -38,6 +39,8 @@ object InliningCalculator { private def getAllCalledFunctions(expressions: List[Node]): List[(String, Boolean)] = expressions.flatMap { case s: VariableDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.initialValue.toList) + case ReturnDispatchStatement(index, params, branches) => + getAllCalledFunctions(List(index)) ++ getAllCalledFunctions(params) ++ getAllCalledFunctions(branches.map(b => b.function)) case s: ArrayDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.elements.getOrElse(Nil)) case s: FunctionDeclarationStatement => getAllCalledFunctions(s.address.toList) ++ getAllCalledFunctions(s.statements.getOrElse(Nil)) case Assignment(VariableExpression(_), expr) => getAllCalledFunctions(expr :: Nil) diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 8331efb8..815d9bfd 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -191,7 +191,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o } yield { val data = Files.readAllBytes(Paths.get(currentDirectory, filePath.mkString)) val slice = optSlice.fold(data) { - case (start, length) => data.drop(start.value.toInt).take(length.value.toInt) + case (start, length) => data.slice(start.value.toInt, start.value.toInt + length.value.toInt) } slice.map(c => LiteralExpression(c & 0xff, 1)).toList } @@ -212,6 +212,8 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o def tightMlExpression: P[Expression] = P(mlParenExpr | functionCall | mlIndexedExpression | atom) // TODO + def tightMlExpressionButNotCall: P[Expression] = P(mlParenExpr | mlIndexedExpression | atom) // TODO + def mlExpression(level: Int): P[Expression] = { val allowedOperators = mlOperators.drop(level).flatten @@ -285,7 +287,7 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o case (p, l, r) => Assignment(l, r).pos(p) } - def keywordStatement: P[ExecutableStatement] = P(returnStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement) + def keywordStatement: P[ExecutableStatement] = P(returnOrDispatchStatement | ifStatement | whileStatement | forStatement | doWhileStatement | inlineAssembly | assignmentStatement) def executableStatement: P[ExecutableStatement] = (position() ~ P(keywordStatement | expressionStatement)).map { case (p, s) => s.pos(p) } @@ -336,7 +338,34 @@ case class MfParser(filename: String, input: String, currentDirectory: String, o def executableStatements: P[Seq[ExecutableStatement]] = "{" ~/ AWS ~/ executableStatement.rep(sep = EOL ~ !"}" ~/ Pass) ~/ AWS ~ "}" - def returnStatement: P[ExecutableStatement] = ("return" ~ !letterOrDigit ~/ HWS ~ mlExpression(nonStatementLevel).?).map(ReturnStatement) + def dispatchLabel: P[ReturnDispatchLabel] = + ("default" ~ !letterOrDigit ~/ AWS ~/ ("(" ~/ position("default branch range") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").?).map{ + case None => DefaultReturnDispatchLabel(None, None) + case Some((_, Seq())) => DefaultReturnDispatchLabel(None, None) + case Some((_, Seq(e))) => DefaultReturnDispatchLabel(None, Some(e)) + case Some((_, Seq(s, e))) => DefaultReturnDispatchLabel(Some(s), Some(e)) + case Some((pos, _)) => + ErrorReporting.error("Invalid default branch declaration", Some(pos)) + DefaultReturnDispatchLabel(None, None) + } | mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS).map(exprs => StandardReturnDispatchLabel(exprs.toList)) + + def dispatchBranch: P[ReturnDispatchBranch] = for { + pos <- position() + l <- dispatchLabel ~/ HWS ~/ "@" ~/ HWS + f <- tightMlExpressionButNotCall ~/ HWS + parameters <- ("(" ~/ position("dispatch actual parameters") ~ AWS ~/ mlExpression(nonStatementLevel).rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").? + } yield ReturnDispatchBranch(l, f, parameters.map(_._2.toList).getOrElse(Nil)).pos(pos) + + def dispatchStatementBody: P[ExecutableStatement] = for { + indexer <- "[" ~/ AWS ~/ mlExpression(nonStatementLevel) ~/ AWS ~/ "]" ~/ AWS + _ <- position("dispatch statement body") + parameters <- ("(" ~/ position("dispatch parameters") ~ AWS ~/ mlLhsExpression.rep(min = 0, sep = AWS ~ "," ~/ AWS) ~ AWS ~/ ")" ~/ "").? + _ <- AWS ~/ position("dispatch statement body") ~/ "{" ~/ AWS + branches <- dispatchBranch.rep(sep = EOL ~ !"}" ~/ Pass) + _ <- AWS ~/ "}" + } yield ReturnDispatchStatement(indexer, parameters.map(_._2.toList).getOrElse(Nil), branches.toList) + + def returnOrDispatchStatement: P[ExecutableStatement] = "return" ~ !letterOrDigit ~/ HWS ~ (dispatchStatementBody | mlExpression(nonStatementLevel).?.map(ReturnStatement)) def ifStatement: P[ExecutableStatement] = for { condition <- "if" ~ !letterOrDigit ~/ HWS ~/ mlExpression(nonStatementLevel) diff --git a/src/test/scala/millfork/test/AssemblySuite.scala b/src/test/scala/millfork/test/AssemblySuite.scala index a4da6086..deb0b6b8 100644 --- a/src/test/scala/millfork/test/AssemblySuite.scala +++ b/src/test/scala/millfork/test/AssemblySuite.scala @@ -132,4 +132,19 @@ class AssemblySuite extends FunSuite with Matchers { | } """.stripMargin)(_.readByte(0xc000) should equal(10)) } + + test("JSR") { + EmuBenchmarkRun( + """ + | byte output @$c000 + | asm void main () { + | JSR thing + | RTS + | } + | + | void thing() { + | output = 10 + | } + """.stripMargin)(_.readByte(0xc000) should equal(10)) + } } diff --git a/src/test/scala/millfork/test/ReturnDispatchSuite.scala b/src/test/scala/millfork/test/ReturnDispatchSuite.scala new file mode 100644 index 00000000..7116c634 --- /dev/null +++ b/src/test/scala/millfork/test/ReturnDispatchSuite.scala @@ -0,0 +1,74 @@ +package millfork.test + +import millfork.test.emu.{EmuBenchmarkRun, EmuCmosBenchmarkRun, EmuUnoptimizedRun} +import org.scalatest.{FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class ReturnDispatchSuite extends FunSuite with Matchers { + + test("Trivial test") { + EmuCmosBenchmarkRun( + """ + | byte output @$c000 + | void main () { + | byte i + | i = 1 + | return [i] { + | 1 @ success + | } + | } + | void success() { + | output = 42 + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(42) + } + } + test("Parameter test") { + EmuCmosBenchmarkRun( + """ + | array output [200] @$c000 + | sbyte param + | byte ptr + | const byte L = 4 + | const byte R = 5 + | const byte W1 = 6 + | const byte W2 = 7 + | void main () { + | ptr = 0 + | handler(W1) + | handler(R) + | handler(W2) + | handler(R) + | handler(W1) + | handler(L) + | handler(L) + | handler(10) + | } + | void handler(byte i) { + | return [i](param) { + | L @ move($ff) + | R @ move(1) + | W1 @ write(1) + | W2 @ write(2) + | default(0,10) @ zero + | } + | } + | void move() { + | ptr += param + | } + | void write() { + | output[ptr] = param + | } + | void zero() { + | output[ptr] = 42 + | } + """.stripMargin) { m => + m.readByte(0xc000) should equal(42) + m.readByte(0xc001) should equal(2) + m.readByte(0xc002) should equal(1) + } + } +} diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index ac8c6658..383d8fe3 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -90,10 +90,12 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], Console.err.flush() println(source) val platform = EmuPlatform.get(cpu) - val options = new CompilationOptions(platform, Map( + val options = CompilationOptions(platform, Map( CompilationFlag.EmitIllegals -> this.emitIllegals, CompilationFlag.DetailedFlowAnalysis -> quantum, CompilationFlag.InlineFunctions -> this.inline, + CompilationFlag.CompactReturnDispatchParams -> true, + CompilationFlag.EmitCmosOpcodes -> (platform.cpu == millfork.Cpu.Cmos), // CompilationFlag.CheckIndexOutOfBounds -> true, )) ErrorReporting.hasErrors = false @@ -113,7 +115,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], val hasOptimizations = assemblyOptimizations.nonEmpty var unoptimizedSize = 0L - // print asm + // print unoptimized asm env.allPreallocatables.foreach { case f: NormalFunction => val result = MfCompiler.compile(CompilationContext(f.environment, f, 0, options)) @@ -129,7 +131,9 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], // compile - val assembler = new Assembler(program, env) + val env2 = new Environment(None, "") + env2.collectDeclarations(program, options) + val assembler = new Assembler(program, env2) val output = assembler.assemble(callGraph, assemblyOptimizations, options) println(";;; compiled: -----------------") output.asm.takeWhile(s => !(s.startsWith(".") && s.contains("= $"))).foreach(println) @@ -148,6 +152,11 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], ErrorReporting.assertNoErrors("Code generation failed") val memoryBank = assembler.mem.banks(0) + if (source.contains("return [")) { + for (_ <- 0 until 10; i <- 0xfffe.to(0, -1)) { + if (memoryBank.readable(i)) memoryBank.readable(i + 1) = true + } + } platform.cpu match { case millfork.Cpu.Cmos => runViaSymon(memoryBank, platform.org, CpuBehavior.CMOS_6502)