From 35ba36ce11a390ce8eccf5498efaadb47af7a985 Mon Sep 17 00:00:00 2001 From: Karol Stasiak Date: Sat, 27 Jul 2019 00:58:10 +0200 Subject: [PATCH] =?UTF-8?q?Function=20pointers=20=E2=80=93=20initial=20ver?= =?UTF-8?q?sion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 2 + docs/lang/operators.md | 14 ++- docs/lang/types.md | 19 ++++ examples/README.md | 2 + examples/crossplatform/fizzbuzz2.mfk | 38 +++++++ include/stdlib_6502.mfk | 6 +- include/stdlib_i80.mfk | 5 + include/zp_reg.mfk | 8 ++ .../opt/ZeropageRegisterOptimizations.scala | 1 + .../z80/opt/ReverseFlowAnalyzer.scala | 10 +- .../compiler/AbstractExpressionCompiler.scala | 30 +++++- .../compiler/mos/MosExpressionCompiler.scala | 38 ++++++- .../compiler/z80/Z80ExpressionCompiler.scala | 31 ++++++ src/main/scala/millfork/env/Environment.scala | 44 ++++++-- src/main/scala/millfork/env/Thing.scala | 38 ++++++- src/main/scala/millfork/node/CallGraph.scala | 8 +- .../output/AbstractInliningCalculator.scala | 3 + .../millfork/test/FunctionPointerSuite.scala | 102 ++++++++++++++++++ src/test/scala/millfork/test/emu/EmuRun.scala | 1 + .../scala/millfork/test/emu/EmuZ80Run.scala | 5 + .../millfork/test/emu/ShouldNotCompile.scala | 20 +++- 21 files changed, 399 insertions(+), 26 deletions(-) create mode 100644 examples/crossplatform/fizzbuzz2.mfk create mode 100644 src/test/scala/millfork/test/FunctionPointerSuite.scala diff --git a/CHANGELOG.md b/CHANGELOG.md index 9291e1bf..6c274c60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ * Added `bool` type. +* Added function pointers – so far quite limited. + * Added arrays of elements of size greater than byte. * Improved passing of register parameters to assembly functions. diff --git a/docs/lang/operators.md b/docs/lang/operators.md index 459eb099..550ccb71 100644 --- a/docs/lang/operators.md +++ b/docs/lang/operators.md @@ -264,12 +264,10 @@ Note that you cannot access a whole array element if it's bigger than 2 bytes, b Other kinds of expressions than the above (even `nonet(byte + byte + byte)`) will not work as expected. * `hi`, `lo`: most/least significant byte of a word -`hi(word)` - +`hi(word)` Furthermore, any type that can be assigned to a variable can be used to convert either from one type either to another type of the same size, -or from a 1-byte integer type to a compatible 2-byte integer type. - +or from a 1-byte integer type to a compatible 2-byte integer type. `byte` → `word` `word` → `pointer` some enum → `byte` @@ -281,4 +279,12 @@ some enum → `word` * `sizeof`: size of the argument in bytes; the argument can be an expression or a type, and the result is a constant of either `byte` or `word` type, depending on situation +* `call`: calls a function via a pointer; +the first argument is the pointer to the function; +the second argument, if present, is the argument to the called function. +The function can have max one parameter, of size max 1 byte, and may return a value of size max 2 bytes. +You can't create typed pointers to other kinds of functions anyway. +If the pointed-to function returns a value, then the result of `call(...)` is the result of the function. +Using `call` on 6502 targets requires at least 4 bytes of zeropage pseudoregister. + diff --git a/docs/lang/types.md b/docs/lang/types.md index 2908ed85..cd2ac3a8 100644 --- a/docs/lang/types.md +++ b/docs/lang/types.md @@ -75,6 +75,25 @@ Its actual value is defined using the feature `NULLPTR`, by default it's 0. `nullptr` isn't directly assignable to non-pointer types. +## Function pointers + +For every type `A` of size 1 (or `void`) and every type `B` of size 1 or 2 (or `void`), +there is a pointer type defined called `function.A.to.B`, which represents functions with a signature like this: + + B function_name(A parameter) + B function_name() // if A is void + +Examples: + + word i + function.void.to.word p1 = f1.pointer + i = call(p1) + function.byte.to.byte p2 = f2.pointer + i += call(p2, 7) + +Using `call` on 6502 requires at least 4 bytes of zeropage pseudoregister. + + ## Boolean types TODO diff --git a/examples/README.md b/examples/README.md index 9fbe2aca..00f130c4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -6,6 +6,8 @@ * [Fizzbuzz](crossplatform/fizzbuzz.mfk) (C64/C16/PET/VIC-20/PET/Atari/Apple II/BBC Micro/ZX Spectrum/PC-88/Armstrad CPC/MSX) – everyone's favourite programming task +* [Fizzbuzz 2](crossplatform/fizzbuzz2.mfk) (C64/C16/PET/VIC-20/PET/Atari/Apple II/BBC Micro/ZX Spectrum/PC-88/Armstrad CPC/MSX) – an alternative, more extensible implemententation of fizzbuzz + * [Text encodings](crossplatform/text_encodings.mfk) (C64/ZX Spectrum) – examples of text encoding features * [Echo](crossplatform/echo.mfk) (C64/C16/ZX Spectrum/PC-88/MSX)– simple text input and output diff --git a/examples/crossplatform/fizzbuzz2.mfk b/examples/crossplatform/fizzbuzz2.mfk new file mode 100644 index 00000000..71ae35c6 --- /dev/null +++ b/examples/crossplatform/fizzbuzz2.mfk @@ -0,0 +1,38 @@ + +import stdio + +bool divisible3(byte x) = x %% 3 == 0 +bool divisible5(byte x) = x %% 5 == 0 + +struct stage { + function.byte.to.bool predicate + pointer text +} + +// can't put text literals directly in struct constructors yet +array fizz = "fizz"z +array buzz = "buzz"z + +array(stage) stages = [ + stage(divisible3.pointer, fizz), + stage(divisible5.pointer, buzz) +] + +void main() { + byte i, s + bool printed + for i,1,to,100 { + printed = false + for s,0,until,stages.length { + if call(stages[s].predicate, i) { + printed = true + putstrz(stages[js].text) + } + } + if not(printed) { + putword(i) + } + putchar(' ') + } +} + diff --git a/include/stdlib_6502.mfk b/include/stdlib_6502.mfk index 73efc8dd..2b69ee94 100644 --- a/include/stdlib_6502.mfk +++ b/include/stdlib_6502.mfk @@ -47,4 +47,8 @@ macro asm void panic() { ? JSR _panic } -array __constant8 = [8] +const array __constant8 = [8] + +#if ZPREG_SIZE < 4 +const array call = [2] +#endif diff --git a/include/stdlib_i80.mfk b/include/stdlib_i80.mfk index 7be04f57..cb85fcf8 100644 --- a/include/stdlib_i80.mfk +++ b/include/stdlib_i80.mfk @@ -53,3 +53,8 @@ _lo_nibble_to_hex_lbl: macro asm void panic() { ? CALL _panic } + +noinline asm word call(word de) { + PUSH DE + RET +} diff --git a/include/zp_reg.mfk b/include/zp_reg.mfk index 7eb77927..33e77afe 100644 --- a/include/zp_reg.mfk +++ b/include/zp_reg.mfk @@ -96,3 +96,11 @@ asm word __div_u16u8u16u8() { } #endif + +#if ZPREG_SIZE >= 4 + +noinline asm word call(word ax) { + JMP ((__reg + 2)) +} + +#endif diff --git a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala index f9abfe93..c2a037a2 100644 --- a/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala +++ b/src/main/scala/millfork/assembly/mos/opt/ZeropageRegisterOptimizations.scala @@ -13,6 +13,7 @@ import millfork.error.FatalErrorReporting object ZeropageRegisterOptimizations { val functionsThatUsePseudoregisterAsInput: Map[String, Set[Int]] = Map( + "call" -> Set(2, 3), "__mul_u8u8u8" -> Set(0, 1), "__mod_u8u8u8u8" -> Set(0, 1), "__div_u8u8u8u8" -> Set(0, 1), diff --git a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala index 99480213..a68479f8 100644 --- a/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala +++ b/src/main/scala/millfork/assembly/z80/opt/ReverseFlowAnalyzer.scala @@ -183,13 +183,13 @@ object ReverseFlowAnalyzer { val cache = new FlowCache[ZLine, CpuImportance]("z80 reverse") - val readsA: Set[String] = Set("__mul_u8u8u8", "__mul_u16u8u16") + val readsA: Set[String] = Set("__mul_u8u8u8", "__mul_u16u8u16", "call") val readsB: Set[String] = Set("") val readsC: Set[String] = Set("") - val readsD: Set[String] = Set("__mul_u8u8u8","__mul_u16u8u16", "__divmod_u16u8u16u8") - val readsE: Set[String] = Set("__mul_u16u8u16") - val readsH: Set[String] = Set("__divmod_u16u8u16u8") - val readsL: Set[String] = Set("__divmod_u16u8u16u8") + val readsD: Set[String] = Set("__mul_u8u8u8","__mul_u16u8u16", "__divmod_u16u8u16u8", "call") + val readsE: Set[String] = Set("__mul_u16u8u16", "call") + val readsH: Set[String] = Set("__divmod_u16u8u16u8", "call") + val readsL: Set[String] = Set("__divmod_u16u8u16u8", "call") //noinspection RedundantNewCaseClass def analyze(f: NormalFunction, code: List[ZLine]): List[CpuImportance] = { diff --git a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala index 312480ea..713d5c0c 100644 --- a/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala @@ -260,7 +260,7 @@ object AbstractExpressionCompiler { val v = env.get[Type]("void") val w = env.get[Type]("word") - val t = expr match { + val t: Type = expr match { case LiteralExpression(_, size) => size match { case 1 => b @@ -372,6 +372,34 @@ object AbstractExpressionCompiler { case Some(List(x)) => toType(!x) case _ => bool } + case f@FunctionCallExpression("call", params) => + params match { + case List(fp) => + getExpressionType(env, log, fp) match { + case fpt@FunctionPointerType(_, _, _, Some(v), Some(r)) => + if (v.name != "void"){ + log.error(s"Invalid function pointer type: $fpt", fp.position) + } + r + case fpt => + log.error(s"Not a function pointer type: $fpt", fp.position) + v + } + case List(fp, pp) => + getExpressionType(env, log, fp) match { + case fpt@FunctionPointerType(_, _, _, Some(p), Some(r)) => + if (!getExpressionType(env, log, pp).isAssignableTo(p)){ + log.error(s"Invalid function pointer type: $fpt", fp.position) + } + r + case fpt => + log.error(s"Not a function pointer type: $fpt", fp.position) + v + } + case _ => + log.error("Invalid call(...) syntax; use either 1 or 2 arguments", f.position) + v + } case FunctionCallExpression("hi", _) => b case FunctionCallExpression("lo", _) => b case FunctionCallExpression("sin", params) => if (params.size < 2) b else getExpressionType(env, log, params(1)) diff --git a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala index c8c17a93..1d7f3e70 100644 --- a/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/mos/MosExpressionCompiler.scala @@ -514,6 +514,12 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { compile(ctx, expr, Some(p -> env.get[Variable]("__reg.loword")), BranchSpec.None) } + def compileToZReg2(ctx: CompilationContext, expr: Expression): List[AssemblyLine] = { + val env = ctx.env + val p = env.get[Type]("pointer") + compile(ctx, expr, Some(p -> env.get[Variable]("__reg.b2b3")), BranchSpec.None) + } + def getPhysicalPointerForDeref(ctx: CompilationContext, pointerExpression: Expression): (List[AssemblyLine], Constant, AddrMode.Value) = { pointerExpression match { case VariableExpression(name) => @@ -1162,6 +1168,36 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { var zeroExtend = false var resultVariable = "" val calculate: List[AssemblyLine] = name match { + case "call" => + params match { + case List(fp) => + getExpressionType(ctx, fp) match { + case FunctionPointerType(_, _, _, _, Some(v)) if (v.name == "void") => + compileToZReg2(ctx, fp) :+ AssemblyLine.absolute(JSR, env.get[ThingInMemory]("call")) + case _ => + ctx.log.error("Not a function pointer", fp.position) + compile(ctx, fp, None, BranchSpec.None) + } + case List(fp, param) => + getExpressionType(ctx, fp) match { + case FunctionPointerType(_, _, _, Some(pt), Some(v)) => + if (pt.size != 1) { + ctx.log.error("Invalid parameter type", param.position) + compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None) + } else if (getExpressionType(ctx, param).isAssignableTo(pt)) { + compileToA(ctx, param) ++ preserveRegisterIfNeeded(ctx, MosRegister.A, compileToZReg2(ctx, fp)) :+ AssemblyLine.absolute(JSR, env.get[ThingInMemory]("call")) + } else { + ctx.log.error("Invalid parameter type", param.position) + compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None) + } + case _ => + ctx.log.error("Not a function pointer", fp.position) + compile(ctx, fp, None, BranchSpec.None) ++ compile(ctx, param, None, BranchSpec.None) + } + case _ => + ctx.log.error("Invalid call syntax", f.position) + Nil + } case "not" => assertBool(ctx, "not", params, 1) compile(ctx, params.head, exprTypeAndVariable, branches.flip) @@ -2062,7 +2098,7 @@ object MosExpressionCompiler extends AbstractExpressionCompiler[AssemblyLine] { val reg = env.get[ThingInMemory]("__reg.loword") (addr, addrSource) match { case (MemoryAddressConstant(th1: Thing), MemoryAddressConstant(th2: Thing)) - if (th1.name == "__reg.loword" || th1.name == "__reg") && (th2.name == "__reg.loword" || th2.name == "__reg") => + if (th1.name == "__reg.loword" || th1.name == "__reg.b2b3" || th1.name == "__reg") && (th2.name == "__reg.loword" || th2.name == "__reg.b2b3" || th2.name == "__reg") => (MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3), MosExpressionCompiler.changesZpreg(prepareSource, 2) || MosExpressionCompiler.changesZpreg(prepareSource, 3)) match { case (_, false) => diff --git a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala index 47e9fc45..0f3830d6 100644 --- a/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala +++ b/src/main/scala/millfork/compiler/z80/Z80ExpressionCompiler.scala @@ -738,6 +738,37 @@ object Z80ExpressionCompiler extends AbstractExpressionCompiler[ZLine] { }) case f@FunctionCallExpression(name, params) => name match { + case "call" => + val callLine = ZLine(CALL, NoRegisters, env.get[ThingInMemory]("call").toAddress) + params match { + case List(fp) => + getExpressionType(ctx, fp) match { + case FunctionPointerType(_, _, _, _, Some(v)) if (v.name == "void") => + compileToDE(ctx, fp) :+ callLine + case _ => + ctx.log.error("Not a function pointer", fp.position) + compile(ctx, fp, ZExpressionTarget.NOTHING, BranchSpec.None) + } + case List(fp, param) => + getExpressionType(ctx, fp) match { + case FunctionPointerType(_, _, _, Some(pt), Some(v)) => + if (pt.size != 1) { + ctx.log.error("Invalid parameter type", param.position) + compileToHL(ctx, fp) ++ compile(ctx, param, ZExpressionTarget.NOTHING) + } else if (getExpressionType(ctx, param).isAssignableTo(pt)) { + compileToDE(ctx, fp) ++ stashDEIfChanged(ctx, compileToA(ctx, param)) :+ callLine + } else { + ctx.log.error("Invalid parameter type", param.position) + compileToHL(ctx, fp) ++ compile(ctx, param, ZExpressionTarget.NOTHING) + } + case _ => + ctx.log.error("Not a function pointer", fp.position) + compile(ctx, fp, ZExpressionTarget.NOTHING) ++ compile(ctx, param, ZExpressionTarget.NOTHING) + } + case _ => + ctx.log.error("Invalid call syntax", f.position) + Nil + } case "not" => assertBool(ctx, "not", params, 1) compile(ctx, params.head, target, branches.flip) diff --git a/src/main/scala/millfork/env/Environment.scala b/src/main/scala/millfork/env/Environment.scala index 029eef59..e9fda4f7 100644 --- a/src/main/scala/millfork/env/Environment.scala +++ b/src/main/scala/millfork/env/Environment.scala @@ -284,6 +284,12 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } def get[T <: Thing : Manifest](name: String, position: Option[Position] = None): T = { + if (name.startsWith("function.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) { + val tokens = name.stripPrefix("function.").split("\\.to\\.", 2) + if (tokens.length == 2) { + return FunctionPointerType(name, tokens(0), tokens(1), maybeGet[Type](tokens(0)), maybeGet[Type](tokens(1))).asInstanceOf[T] + } + } if (name.startsWith("pointer.") && implicitly[Manifest[T]].runtimeClass.isAssignableFrom(classOf[PointerType])) { val targetName = name.stripPrefix("pointer.") val target = maybeGet[VariableType](targetName) @@ -1147,9 +1153,16 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa } } + private def getFunctionPointerType(f: FunctionInMemory) = f.params.types match { + case List() => + get[Type]("function.void.to." + f.returnType.name) + case p :: _ => // TODO: this only handles one type though! + get[Type]("function." + p.name + ".to." + f.returnType.name) + } + private def registerAddressConstant(thing: ThingInMemory, position: Option[Position], options: CompilationOptions, targetType: Option[Type]): Unit = { + val b = get[Type]("byte") if (!thing.zeropage && options.flag(CompilationFlag.LUnixRelocatableCode)) { - val b = get[Type]("byte") val w = get[Type]("word") val relocatable = UninitializedMemoryVariable(thing.name + ".addr", w, VariableAllocationMethod.Static, None, defaultVariableAlignment(options, 2), isVolatile = false) val addr = relocatable.toAddress @@ -1166,20 +1179,36 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa addThing(ConstantThing(thing.name + ".rawaddr", rawaddr, get[Type]("pointer")), position) addThing(ConstantThing(thing.name + ".rawaddr.hi", rawaddr.hiByte, get[Type]("byte")), position) addThing(ConstantThing(thing.name + ".rawaddr.lo", rawaddr.loByte, get[Type]("byte")), position) + thing match { + case f: FunctionInMemory if f.canBePointedTo => + val typedPointer = RelativeVariable(thing.name + ".pointer", addr, getFunctionPointerType(f), zeropage = false, None, isVolatile = false) + addThing(typedPointer, position) + addThing(RelativeVariable(thing.name + ".pointer.hi", addr + 1, b, zeropage = false, None, isVolatile = false), position) + addThing(RelativeVariable(thing.name + ".pointer.lo", addr, b, zeropage = false, None, isVolatile = false), position) + case _ => + } } else { val addr = thing.toAddress addThing(ConstantThing(thing.name + ".addr", addr, get[Type]("pointer")), position) - addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, get[Type]("byte")), position) - addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, get[Type]("byte")), position) + addThing(ConstantThing(thing.name + ".addr.hi", addr.hiByte, b), position) + addThing(ConstantThing(thing.name + ".addr.lo", addr.loByte, b), position) addThing(ConstantThing(thing.name + ".rawaddr", addr, get[Type]("pointer")), position) - addThing(ConstantThing(thing.name + ".rawaddr.hi", addr.hiByte, get[Type]("byte")), position) - addThing(ConstantThing(thing.name + ".rawaddr.lo", addr.loByte, get[Type]("byte")), position) + addThing(ConstantThing(thing.name + ".rawaddr.hi", addr.hiByte, b), position) + addThing(ConstantThing(thing.name + ".rawaddr.lo", addr.loByte, b), position) targetType.foreach { tt => val pointerType = PointerType("pointer." + tt.name, tt.name, Some(tt)) val typedPointer = RelativeVariable(thing.name + ".pointer", addr, pointerType, zeropage = false, None, isVolatile = false) addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position) - addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, get[Type]("byte")), position) - addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, get[Type]("byte")), position) + addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, b), position) + addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, b), position) + } + thing match { + case f: FunctionInMemory if f.canBePointedTo => + val pointerType = getFunctionPointerType(f) + addThing(ConstantThing(thing.name + ".pointer", addr, pointerType), position) + addThing(ConstantThing(thing.name + ".pointer.hi", addr.hiByte, b), position) + addThing(ConstantThing(thing.name + ".pointer.lo", addr.loByte, b), position) + case _ => } } } @@ -1698,6 +1727,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa if (function.params.length != actualParams.length) { log.error(s"Invalid number of parameters for function `$name`", actualParams.headOption.flatMap(_._2.position)) } + if (name == "call") return Some(function) function.params match { case NormalParamSignature(params) => function.params.types.zip(actualParams).zip(params).foreach { case ((required, (actual, expr)), m) => diff --git a/src/main/scala/millfork/env/Thing.scala b/src/main/scala/millfork/env/Thing.scala index c9c605c1..049e6d6f 100644 --- a/src/main/scala/millfork/env/Thing.scala +++ b/src/main/scala/millfork/env/Thing.scala @@ -82,6 +82,14 @@ case class PointerType(name: String, targetName: String, var target: Option[Type override def pointerTargetName: String = targetName } +case class FunctionPointerType(name: String, paramTypeName:String, returnTypeName: String, var paramType: Option[Type], var returnType: Option[Type]) extends VariableType { + def size = 2 + + override def isSigned: Boolean = false + + override def isPointy: Boolean = false +} + case object NullType extends VariableType { override def size: Int = 2 @@ -91,9 +99,9 @@ case object NullType extends VariableType { override def isPointy: Boolean = true - override def isSubtypeOf(other: Type): Boolean = this == other || other.isPointy && other.size == 2 + override def isSubtypeOf(other: Type): Boolean = this == other || (other.isPointy || other.isInstanceOf[FunctionPointerType]) && other.size == 2 - override def isAssignableTo(targetType: Type): Boolean = this == targetType || targetType.isPointy && targetType.size == 2 + override def isAssignableTo(targetType: Type): Boolean = this == targetType || (targetType.isPointy || targetType.isInstanceOf[FunctionPointerType]) && targetType.size == 2 } case class EnumType(name: String, count: Option[Int]) extends VariableType { @@ -340,6 +348,8 @@ sealed trait MangledFunction extends CallableThing { def params: ParamSignature def interrupt: Boolean + + def canBePointedTo: Boolean } case class EmptyFunction(name: String, @@ -348,6 +358,8 @@ case class EmptyFunction(name: String, override def params = EmptyFunctionParamSignature(paramType) override def interrupt = false + + override def canBePointedTo: Boolean = false } case class MacroFunction(name: String, @@ -356,6 +368,8 @@ case class MacroFunction(name: String, environment: Environment, code: List[ExecutableStatement]) extends MangledFunction { override def interrupt = false + + override def canBePointedTo: Boolean = false } sealed trait FunctionInMemory extends MangledFunction with ThingInMemory { @@ -381,6 +395,8 @@ case class ExternFunction(name: String, override def zeropage: Boolean = false override def isVolatile: Boolean = false + + override def canBePointedTo: Boolean = !interrupt && returnType.size <= 2 && params.canBePointedTo && name !="call" } case class NormalFunction(name: String, @@ -402,6 +418,8 @@ case class NormalFunction(name: String, override def zeropage: Boolean = false override def isVolatile: Boolean = false + + override def canBePointedTo: Boolean = !interrupt && returnType.size <= 2 && params.canBePointedTo && name !="call" } case class ConstantThing(name: String, value: Constant, typ: Type) extends TypedThing with VariableLikeThing with IndexableThing { @@ -416,12 +434,16 @@ trait ParamSignature { def types: List[Type] def length: Int + + def canBePointedTo: Boolean } case class NormalParamSignature(params: List[VariableInMemory]) extends ParamSignature { override def length: Int = params.length override def types: List[Type] = params.map(_.typ) + + def canBePointedTo: Boolean = params.size <= 1 && params.forall(_.typ.size.<=(1)) } sealed trait ParamPassingConvention { @@ -464,17 +486,27 @@ object AssemblyParameterPassingBehaviour extends Enumeration { val Copy, ByReference, ByConstant = Value } -case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value) +case class AssemblyParam(typ: Type, variable: TypedThing, behaviour: AssemblyParameterPassingBehaviour.Value) { + def canBePointedTo: Boolean = behaviour == AssemblyParameterPassingBehaviour.Copy && (variable match { + case RegisterVariable(MosRegister.A | MosRegister.AX, _) => true + case ZRegisterVariable(ZRegister.A | ZRegister.HL, _) => true + case _ => false + }) +} case class AssemblyParamSignature(params: List[AssemblyParam]) extends ParamSignature { override def length: Int = params.length override def types: List[Type] = params.map(_.typ) + + def canBePointedTo: Boolean = params.size <= 1 && params.forall(_.canBePointedTo) } case class EmptyFunctionParamSignature(paramType: Type) extends ParamSignature { override def length: Int = 1 override def types: List[Type] = List(paramType) + + def canBePointedTo: Boolean = false } \ No newline at end of file diff --git a/src/main/scala/millfork/node/CallGraph.scala b/src/main/scala/millfork/node/CallGraph.scala index f34c1c0e..0022d963 100644 --- a/src/main/scala/millfork/node/CallGraph.scala +++ b/src/main/scala/millfork/node/CallGraph.scala @@ -59,12 +59,16 @@ abstract class CallGraph(program: Program, log: Logger) { case s: SumExpression => s.expressions.foreach(expr => add(currentFunction, callingFunctions, expr._2)) case x: VariableExpression => - val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") + val varName = x.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr").stripSuffix(".pointer") everCalledFunctions += varName + entryPoints += varName // TODO: figure out how to interpret pointed-to functions case i: IndexedExpression => - val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr") + val varName = i.name.stripSuffix(".hi").stripSuffix(".lo").stripSuffix(".addr").stripSuffix(".pointer") everCalledFunctions += varName + entryPoints += varName add(currentFunction, callingFunctions, i.index) + case i: DerefExpression => + add(currentFunction, callingFunctions, i.inner) case i: DerefDebuggingExpression => add(currentFunction, callingFunctions, i.inner) case IndirectFieldExpression(root, firstIndices, fields) => diff --git a/src/main/scala/millfork/output/AbstractInliningCalculator.scala b/src/main/scala/millfork/output/AbstractInliningCalculator.scala index d59b8b6a..92948c7f 100644 --- a/src/main/scala/millfork/output/AbstractInliningCalculator.scala +++ b/src/main/scala/millfork/output/AbstractInliningCalculator.scala @@ -77,6 +77,9 @@ abstract class AbstractInliningCalculator[T <: AbstractCode] { case s: Statement => getAllCalledFunctions(s.getAllExpressions) case s: VariableExpression => Set( s.name, + s.name.stripSuffix(".pointer"), + s.name.stripSuffix(".pointer.lo"), + s.name.stripSuffix(".pointer.hi"), s.name.stripSuffix(".addr"), s.name.stripSuffix(".hi"), s.name.stripSuffix(".lo"), diff --git a/src/test/scala/millfork/test/FunctionPointerSuite.scala b/src/test/scala/millfork/test/FunctionPointerSuite.scala new file mode 100644 index 00000000..7b4a21c6 --- /dev/null +++ b/src/test/scala/millfork/test/FunctionPointerSuite.scala @@ -0,0 +1,102 @@ +package millfork.test + +import millfork.Cpu +import millfork.test.emu.{EmuCrossPlatformBenchmarkRun, EmuUnoptimizedCrossPlatformRun, ShouldNotCompile} +import org.scalatest.{AppendedClues, FunSuite, Matchers} + +/** + * @author Karol Stasiak + */ +class FunctionPointerSuite extends FunSuite with Matchers with AppendedClues{ + + test("Function pointers 1") { + EmuCrossPlatformBenchmarkRun(Cpu.Mos, Cpu.Cmos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | + | byte output @$c000 + | void f1() { + | output = 100 + | } + | + | void main() { + | function.void.to.void p1 + | p1 = f1.pointer + | call(p1) + | } + | + """.stripMargin) { m => + m.readByte(0xc000) should equal(100) + } + } + + test("Function pointers 2") { + EmuUnoptimizedCrossPlatformRun (Cpu.Mos, Cpu.Cmos, Cpu.Z80, Cpu.Intel8080, Cpu.Sharp)( + """ + | const byte COUNT = 128 + | array output0[COUNT] @$c000 + | array output1[COUNT] @$c100 + | array output2[COUNT] @$c200 + | array output3[COUNT] @$c300 + | + | void tabulate(pointer target, function.byte.to.byte f) { + | byte i + | for i,0,until,COUNT { + | target[i] = call(f, i) + | } + | } + | + | byte double(byte x) = x*2 + | byte negate(byte x) = 0-x + | byte zero(byte x) = 0 + | byte id(byte x) = x + | + | void main() { + | tabulate(output0, zero.pointer) + | tabulate(output1, id.pointer) + | tabulate(output2, double.pointer) + | tabulate(output3, negate.pointer) + | } + | + """.stripMargin) { m => + for (i <- 0 until 0x80) { + m.readByte(0xc000 + i) should equal(0) withClue ("zero " + i) + m.readByte(0xc100 + i) should equal(i) withClue ("id " + i) + m.readByte(0xc200 + i) should equal(i * 2) withClue ("double " + i) + m.readByte(0xc300 + i) should equal((256 - i) & 0xff) withClue ("negate " + i) + } + } + } + + test("Function pointers: invalid types") { + ShouldNotCompile( + """ + |void main() { + | call(main.pointer, 1) + |} + |""".stripMargin) + ShouldNotCompile( + """ + |void f(byte a) = 0 + |void main() { + | call(f.pointer) + |} + |""".stripMargin) + ShouldNotCompile( + """ + |enum e {} + |void f(e a) = 0 + |void main() { + | call(f.pointer, 0) + |} + |""".stripMargin) + ShouldNotCompile( + """ + |enum e {} + |void f(byte a) = 0 + |void main() { + | call(f.pointer, e(7)) + |} + |""".stripMargin) + } + +} diff --git a/src/test/scala/millfork/test/emu/EmuRun.scala b/src/test/scala/millfork/test/emu/EmuRun.scala index fd85cd53..397a919d 100644 --- a/src/test/scala/millfork/test/emu/EmuRun.scala +++ b/src/test/scala/millfork/test/emu/EmuRun.scala @@ -156,6 +156,7 @@ class EmuRun(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimization], if (native16 && platform.cpu != millfork.Cpu.Sixteen) throw new IllegalStateException var effectiveSource = source if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" + if (source.contains("call(")) effectiveSource += "\nnoinline asm word call(word ax) {\nJMP ((__reg.b2b3))\n}\n" if (native16) effectiveSource += """ | diff --git a/src/test/scala/millfork/test/emu/EmuZ80Run.scala b/src/test/scala/millfork/test/emu/EmuZ80Run.scala index aa0a7dcc..81f4bd85 100644 --- a/src/test/scala/millfork/test/emu/EmuZ80Run.scala +++ b/src/test/scala/millfork/test/emu/EmuZ80Run.scala @@ -92,6 +92,11 @@ class EmuZ80Run(cpu: millfork.Cpu.Value, nodeOptimizations: List[NodeOptimizatio log.verbosity = 999 var effectiveSource = source if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" + if (source.contains("call(")) { + if (options.flag(CompilationFlag.UseIntelSyntaxForInput)) + effectiveSource += "\nnoinline asm word call(word de) {\npush d\nret\n}\n" + else effectiveSource += "\nnoinline asm word call(word de) {\npush de\nret\n}\n" + } log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq)) val PreprocessingResult(preprocessedSource, features, pragmas) = Preprocessor.preprocessForTest(options, effectiveSource) // tests use Intel syntax only when forced to: diff --git a/src/test/scala/millfork/test/emu/ShouldNotCompile.scala b/src/test/scala/millfork/test/emu/ShouldNotCompile.scala index 1089c62a..c722e5f9 100644 --- a/src/test/scala/millfork/test/emu/ShouldNotCompile.scala +++ b/src/test/scala/millfork/test/emu/ShouldNotCompile.scala @@ -8,7 +8,7 @@ import millfork.compiler.{CompilationContext, LabelGenerator} import millfork.compiler.mos.MosCompiler import millfork.env.{Environment, InitializedArray, InitializedMemoryVariable, NormalFunction} import millfork.node.StandardCallGraph -import millfork.parser.{MosParser, PreprocessingResult, Preprocessor} +import millfork.parser.{MosParser, PreprocessingResult, Preprocessor, Z80Parser} import millfork._ import millfork.compiler.m6809.M6809Compiler import millfork.compiler.z80.Z80Compiler @@ -36,11 +36,27 @@ object ShouldNotCompile extends Matchers { log.verbosity = 999 var effectiveSource = source if (!source.contains("_panic")) effectiveSource += "\n void _panic(){while(true){}}" + if (source.contains("call(")) { + platform.cpuFamily match { + case CpuFamily.M6502 => + effectiveSource += "\nnoinline asm word call(word ax) {\nJMP ((__reg.b2b3))\n}\n" + case CpuFamily.I80 => + if (options.flag(CompilationFlag.UseIntelSyntaxForInput)) + effectiveSource += "\nnoinline asm word call(word de) {\npush d\nret\n}\n" + else effectiveSource += "\nnoinline asm word call(word de) {\npush de\nret\n}\n" + } + } if (source.contains("import zp_reg")) effectiveSource += Files.readAllLines(Paths.get("include/zp_reg.mfk"), StandardCharsets.US_ASCII).asScala.mkString("\n", "\n", "") log.setSource(Some(effectiveSource.linesIterator.toIndexedSeq)) val PreprocessingResult(preprocessedSource, features, _) = Preprocessor.preprocessForTest(options, effectiveSource) - val parserF = MosParser("", preprocessedSource, "", options, features) + val parserF = + platform.cpuFamily match { + case CpuFamily.M6502 => + MosParser("", preprocessedSource, "", options, features) + case CpuFamily.I80 => + Z80Parser("", preprocessedSource, "", options, features, options.flag(CompilationFlag.UseIntelSyntaxForInput)) + } parserF.toAst match { case Success(program, _) => log.assertNoErrors("Parse failed")